From ffc68a12b4b3655198b549022eeeb9509edd6476 Mon Sep 17 00:00:00 2001 From: Runyang2 Date: Tue, 24 Mar 2026 15:32:44 -0400 Subject: [PATCH] Add wildfire models, docs, and benchmark real-data utilities Docs source updated. HTML not rebuilt locally because sphinx_design is missing. --- docs/source/datasets/frap_fire_perimeters.rst | 99 ++ docs/source/datasets/geomac_historical.rst | 98 ++ docs/source/datasets/goes_geocolor.rst | 85 ++ docs/source/datasets/goesr_fdcf.rst | 99 ++ docs/source/datasets/hms_smoke.rst | 99 ++ docs/source/datasets/hpwren_weather.rst | 85 ++ docs/source/datasets/hrrr.rst | 86 ++ docs/source/datasets/landscan_population.rst | 86 ++ docs/source/datasets/nasa_gibs.rst | 85 ++ docs/source/datasets/ndfd.rst | 85 ++ docs/source/datasets/nohrsc_snodas.rst | 85 ++ docs/source/datasets/spot_forecast.rst | 87 ++ docs/source/datasets/synoptic_weather.rst | 85 ++ docs/source/datasets/wrc_housing_density.rst | 85 ++ docs/source/modules/models_asufm.rst | 121 +-- docs/source/modules/models_attention_unet.rst | 42 + .../source/modules/models_convgru_trajgru.rst | 42 + docs/source/modules/models_convlstm.rst | 42 + docs/source/modules/models_deep_ensemble.rst | 42 + docs/source/modules/models_deeplabv3p.rst | 42 + docs/source/modules/models_earthfarseer.rst | 42 + docs/source/modules/models_earthformer.rst | 42 + docs/source/modules/models_firecastnet.rst | 112 +-- docs/source/modules/models_firemm_ir.rst | 56 ++ docs/source/modules/models_firepred.rst | 46 + docs/source/modules/models_forefire.rst | 123 +-- ...models_gemini_25_pro_wildfire_prompted.rst | 47 + .../models_internvl3_wildfire_prompted.rst | 47 + docs/source/modules/models_lightgbm.rst | 42 + .../models_llama4_wildfire_prompted.rst | 48 + .../modules/models_logistic_regression.rst | 42 + docs/source/modules/models_mau.rst | 42 + .../modules/models_modis_active_fire_c61.rst | 48 + docs/source/modules/models_predrnn_v2.rst | 42 + .../modules/models_prithvi_burnscars.rst | 55 ++ 
.../source/modules/models_prithvi_eo_2_tl.rst | 57 ++ docs/source/modules/models_prithvi_wxc.rst | 57 ++ .../models_qwen25_vl_wildfire_prompted.rst | 52 ++ docs/source/modules/models_rainformer.rst | 42 + docs/source/modules/models_random_forest.rst | 42 + docs/source/modules/models_resnet18_unet.rst | 42 + docs/source/modules/models_segformer.rst | 42 + docs/source/modules/models_swin_unet.rst | 42 + docs/source/modules/models_swinlstm.rst | 42 + docs/source/modules/models_tcn.rst | 42 + docs/source/modules/models_ts_satfire.rst | 50 + docs/source/modules/models_unet.rst | 42 + docs/source/modules/models_utae.rst | 42 + .../modules/models_viirs_375m_active_fire.rst | 48 + docs/source/modules/models_vit_segmenter.rst | 42 + .../modules/models_wildfire_forecasting.rst | 111 --- docs/source/modules/models_wildfiregpt.rst | 55 ++ docs/source/modules/models_wrf_sfire.rst | 121 +-- docs/source/modules/models_xgboost.rst | 42 + docs/source/pyhazards_datasets.rst | 476 +++++++++- docs/source/pyhazards_models.rst | 861 +++++++++++++++++- .../benchmark_cards/wildfire_benchmark.yaml | 3 +- pyhazards/benchmarks/__init__.py | 58 ++ .../wildfire_benchmark/REAL_DATA_2024_PLAN.md | 286 ++++++ .../benchmarks/wildfire_benchmark/__init__.py | 47 + .../wildfire_benchmark/adapters/__init__.py | 10 + .../wildfire_benchmark/adapters/base.py | 89 ++ .../wildfire_benchmark/adapters/registry.py | 14 + .../wildfire_benchmark/adapters/synthetic.py | 138 +++ .../wildfire_benchmark/artifacts.py | 149 +++ .../wildfire_benchmark/cache_builder.py | 378 ++++++++ .../benchmarks/wildfire_benchmark/catalog.py | 53 ++ .../wildfire_benchmark/experiment_settings.py | 134 +++ .../benchmarks/wildfire_benchmark/layout.py | 62 ++ .../wildfire_benchmark/real_runner.py | 653 +++++++++++++ .../benchmarks/wildfire_benchmark/runner.py | 158 ++++ .../wildfire_benchmark/cache_2024_v1.yaml | 30 + .../wildfire_benchmark/model_catalog_22.json | 295 ++++++ .../model_catalog_extensions_v1.json | 284 ++++++ 
.../track_o_2024_real_v1.json | 86 ++ .../wildfire_benchmark/track_o_2024_v1.json | 82 ++ pyhazards/datasets/__init__.py | 16 +- pyhazards/datasets/wildfire/__init__.py | 15 +- .../datasets/wildfire/real_track_o_2024.py | 451 +++++++++ pyhazards/models/__init__.py | 597 +++++++++++- pyhazards/models/_wildfire_benchmark_utils.py | 110 +++ pyhazards/models/asufm.py | 599 +++++++++++- pyhazards/models/attention_unet.py | 432 +++++++++ pyhazards/models/convgru_trajgru.py | 481 ++++++++++ pyhazards/models/convlstm.py | 470 ++++++++++ pyhazards/models/deep_ensemble.py | 464 ++++++++++ pyhazards/models/deeplabv3p.py | 452 +++++++++ pyhazards/models/earthfarseer.py | 453 +++++++++ pyhazards/models/earthformer.py | 448 +++++++++ pyhazards/models/firecastnet.py | 2 +- pyhazards/models/firemm_ir.py | 181 ++++ pyhazards/models/firepred.py | 98 ++ pyhazards/models/forefire.py | 4 +- .../models/gemini_25_pro_wildfire_prompted.py | 39 + .../models/internvl3_wildfire_prompted.py | 39 + pyhazards/models/lightgbm.py | 65 ++ pyhazards/models/llama4_wildfire_prompted.py | 39 + pyhazards/models/logistic_regression.py | 44 + pyhazards/models/mau.py | 512 +++++++++++ pyhazards/models/modis_active_fire_c61.py | 112 +++ pyhazards/models/predrnn_v2.py | 452 +++++++++ pyhazards/models/prithvi_burnscars.py | 119 +++ pyhazards/models/prithvi_eo_2_tl.py | 251 +++++ pyhazards/models/prithvi_wxc.py | 284 ++++++ .../models/qwen25_vl_wildfire_prompted.py | 169 ++++ pyhazards/models/rainformer.py | 459 ++++++++++ pyhazards/models/random_forest.py | 46 + pyhazards/models/resnet18_unet.py | 465 ++++++++++ pyhazards/models/segformer.py | 568 ++++++++++++ pyhazards/models/swin_unet.py | 527 +++++++++++ pyhazards/models/swinlstm.py | 506 ++++++++++ pyhazards/models/tcn.py | 489 ++++++++++ pyhazards/models/ts_satfire.py | 84 ++ pyhazards/models/unet.py | 482 ++++++++++ pyhazards/models/utae.py | 446 +++++++++ pyhazards/models/viirs_375m_active_fire.py | 111 +++ pyhazards/models/vit_segmenter.py | 471 
++++++++++ pyhazards/models/wildfire_aspp.py | 4 +- pyhazards/models/wildfire_forecasting.py | 91 -- pyhazards/models/wildfiregpt.py | 168 ++++ pyhazards/models/wrf_sfire.py | 17 +- pyhazards/models/xgboost.py | 62 ++ scripts/align_wildfire_2024_fuel.py | 33 + scripts/build_wildfire_2024_cache.py | 35 + scripts/run_wildfire_2024_real_baselines.py | 57 ++ scripts/run_wildfire_smoke_batch.py | 58 ++ 126 files changed, 19698 insertions(+), 652 deletions(-) create mode 100644 docs/source/datasets/frap_fire_perimeters.rst create mode 100644 docs/source/datasets/geomac_historical.rst create mode 100644 docs/source/datasets/goes_geocolor.rst create mode 100644 docs/source/datasets/goesr_fdcf.rst create mode 100644 docs/source/datasets/hms_smoke.rst create mode 100644 docs/source/datasets/hpwren_weather.rst create mode 100644 docs/source/datasets/hrrr.rst create mode 100644 docs/source/datasets/landscan_population.rst create mode 100644 docs/source/datasets/nasa_gibs.rst create mode 100644 docs/source/datasets/ndfd.rst create mode 100644 docs/source/datasets/nohrsc_snodas.rst create mode 100644 docs/source/datasets/spot_forecast.rst create mode 100644 docs/source/datasets/synoptic_weather.rst create mode 100644 docs/source/datasets/wrc_housing_density.rst create mode 100644 docs/source/modules/models_attention_unet.rst create mode 100644 docs/source/modules/models_convgru_trajgru.rst create mode 100644 docs/source/modules/models_convlstm.rst create mode 100644 docs/source/modules/models_deep_ensemble.rst create mode 100644 docs/source/modules/models_deeplabv3p.rst create mode 100644 docs/source/modules/models_earthfarseer.rst create mode 100644 docs/source/modules/models_earthformer.rst create mode 100644 docs/source/modules/models_firemm_ir.rst create mode 100644 docs/source/modules/models_firepred.rst create mode 100644 docs/source/modules/models_gemini_25_pro_wildfire_prompted.rst create mode 100644 docs/source/modules/models_internvl3_wildfire_prompted.rst create mode 
100644 docs/source/modules/models_lightgbm.rst create mode 100644 docs/source/modules/models_llama4_wildfire_prompted.rst create mode 100644 docs/source/modules/models_logistic_regression.rst create mode 100644 docs/source/modules/models_mau.rst create mode 100644 docs/source/modules/models_modis_active_fire_c61.rst create mode 100644 docs/source/modules/models_predrnn_v2.rst create mode 100644 docs/source/modules/models_prithvi_burnscars.rst create mode 100644 docs/source/modules/models_prithvi_eo_2_tl.rst create mode 100644 docs/source/modules/models_prithvi_wxc.rst create mode 100644 docs/source/modules/models_qwen25_vl_wildfire_prompted.rst create mode 100644 docs/source/modules/models_rainformer.rst create mode 100644 docs/source/modules/models_random_forest.rst create mode 100644 docs/source/modules/models_resnet18_unet.rst create mode 100644 docs/source/modules/models_segformer.rst create mode 100644 docs/source/modules/models_swin_unet.rst create mode 100644 docs/source/modules/models_swinlstm.rst create mode 100644 docs/source/modules/models_tcn.rst create mode 100644 docs/source/modules/models_ts_satfire.rst create mode 100644 docs/source/modules/models_unet.rst create mode 100644 docs/source/modules/models_utae.rst create mode 100644 docs/source/modules/models_viirs_375m_active_fire.rst create mode 100644 docs/source/modules/models_vit_segmenter.rst delete mode 100644 docs/source/modules/models_wildfire_forecasting.rst create mode 100644 docs/source/modules/models_wildfiregpt.rst create mode 100644 docs/source/modules/models_xgboost.rst create mode 100644 pyhazards/benchmarks/wildfire_benchmark/REAL_DATA_2024_PLAN.md create mode 100644 pyhazards/benchmarks/wildfire_benchmark/__init__.py create mode 100644 pyhazards/benchmarks/wildfire_benchmark/adapters/__init__.py create mode 100644 pyhazards/benchmarks/wildfire_benchmark/adapters/base.py create mode 100644 pyhazards/benchmarks/wildfire_benchmark/adapters/registry.py create mode 100644 
pyhazards/benchmarks/wildfire_benchmark/adapters/synthetic.py create mode 100644 pyhazards/benchmarks/wildfire_benchmark/artifacts.py create mode 100644 pyhazards/benchmarks/wildfire_benchmark/cache_builder.py create mode 100644 pyhazards/benchmarks/wildfire_benchmark/catalog.py create mode 100644 pyhazards/benchmarks/wildfire_benchmark/experiment_settings.py create mode 100644 pyhazards/benchmarks/wildfire_benchmark/layout.py create mode 100644 pyhazards/benchmarks/wildfire_benchmark/real_runner.py create mode 100644 pyhazards/benchmarks/wildfire_benchmark/runner.py create mode 100644 pyhazards/configs/wildfire_benchmark/cache_2024_v1.yaml create mode 100644 pyhazards/configs/wildfire_benchmark/model_catalog_22.json create mode 100644 pyhazards/configs/wildfire_benchmark/model_catalog_extensions_v1.json create mode 100644 pyhazards/configs/wildfire_benchmark/track_o_2024_real_v1.json create mode 100644 pyhazards/configs/wildfire_benchmark/track_o_2024_v1.json create mode 100644 pyhazards/datasets/wildfire/real_track_o_2024.py create mode 100644 pyhazards/models/_wildfire_benchmark_utils.py create mode 100644 pyhazards/models/attention_unet.py create mode 100644 pyhazards/models/convgru_trajgru.py create mode 100644 pyhazards/models/convlstm.py create mode 100644 pyhazards/models/deep_ensemble.py create mode 100644 pyhazards/models/deeplabv3p.py create mode 100644 pyhazards/models/earthfarseer.py create mode 100644 pyhazards/models/earthformer.py create mode 100644 pyhazards/models/firemm_ir.py create mode 100644 pyhazards/models/firepred.py create mode 100644 pyhazards/models/gemini_25_pro_wildfire_prompted.py create mode 100644 pyhazards/models/internvl3_wildfire_prompted.py create mode 100644 pyhazards/models/lightgbm.py create mode 100644 pyhazards/models/llama4_wildfire_prompted.py create mode 100644 pyhazards/models/logistic_regression.py create mode 100644 pyhazards/models/mau.py create mode 100644 pyhazards/models/modis_active_fire_c61.py create mode 100644 
pyhazards/models/predrnn_v2.py create mode 100644 pyhazards/models/prithvi_burnscars.py create mode 100644 pyhazards/models/prithvi_eo_2_tl.py create mode 100644 pyhazards/models/prithvi_wxc.py create mode 100644 pyhazards/models/qwen25_vl_wildfire_prompted.py create mode 100644 pyhazards/models/rainformer.py create mode 100644 pyhazards/models/random_forest.py create mode 100644 pyhazards/models/resnet18_unet.py create mode 100644 pyhazards/models/segformer.py create mode 100644 pyhazards/models/swin_unet.py create mode 100644 pyhazards/models/swinlstm.py create mode 100644 pyhazards/models/tcn.py create mode 100644 pyhazards/models/ts_satfire.py create mode 100644 pyhazards/models/unet.py create mode 100644 pyhazards/models/utae.py create mode 100644 pyhazards/models/viirs_375m_active_fire.py create mode 100644 pyhazards/models/vit_segmenter.py delete mode 100644 pyhazards/models/wildfire_forecasting.py create mode 100644 pyhazards/models/wildfiregpt.py create mode 100644 pyhazards/models/xgboost.py create mode 100644 scripts/align_wildfire_2024_fuel.py create mode 100644 scripts/build_wildfire_2024_cache.py create mode 100644 scripts/run_wildfire_2024_real_baselines.py create mode 100644 scripts/run_wildfire_smoke_batch.py diff --git a/docs/source/datasets/frap_fire_perimeters.rst b/docs/source/datasets/frap_fire_perimeters.rst new file mode 100644 index 00000000..f856f250 --- /dev/null +++ b/docs/source/datasets/frap_fire_perimeters.rst @@ -0,0 +1,99 @@ +.. This file is generated by scripts/render_dataset_docs.py. Do not edit by hand. + +FRAP Fire Perimeters +==================== + +California's authoritative historical fire perimeter archive maintained by CAL FIRE FRAP. + +Overview +-------- + +FRAP Fire Perimeters is CAL FIRE's statewide historical perimeter archive for large fires and other mapped wildfire events in California. 
+ +In PyHazards it serves as a regional authoritative perimeter source for wildfire evaluation, event backfilling, and comparison against national incident feeds such as WFIGS or satellite detections such as FIRMS. + +At a Glance +----------- + +.. list-table:: + :widths: 28 72 + :stub-columns: 1 + + * - Provider + - CAL FIRE / Fire and Resource Assessment Program (FRAP) + * - Hazard Family + - Wildfire + * - Source Role + - Historical Perimeters + * - Coverage + - California + * - Geometry + - Vector fire perimeter polygons + * - Spatial Resolution + - Event-level polygon geometries + * - Temporal Resolution + - Event-based historical perimeter archive + * - Update Cadence + - Annual spring releases with new fire-season perimeters + * - Period of Record + - Historical California fire perimeter archive spanning multiple decades + * - Formats + - Shapefile, file geodatabase downloads, and zipped GIS packages + * - Inspection CLI + - ``ogrinfo -so "/home/runyang/ryang/FRAP_Fire_Perimeters/shapefile/California_Fire_Perimeters_(all).shp" "California_Fire_Perimeters_(all)"`` + +Data Characteristics +-------------------- + +- Statewide polygon archive focused on historical fire perimeters. +- More suitable for perimeter validation and retrospective analysis than for near-real-time detection. +- Includes known completeness limitations for older fires and should be interpreted with source caveats in mind. +- Complements national incident feeds by providing California-specific historical depth. + +Typical Use Cases +~~~~~~~~~~~~~~~~~ + +- Historical wildfire perimeter validation in California. +- Regional benchmark label curation and retrospective fire footprint analysis. +- Cross-checking incident records against mapped burn extents. + +Access +------ + +Use the links below to access the upstream source or its public documentation. 
+ +- `CAL FIRE FRAP Fire Perimeters `_ +- `CAL FIRE Fire Perimeters metadata `_ + +PyHazards Usage +--------------- + +Use the local shapefile or zipped archive as an external inspection-first source when you need California-specific historical perimeters in wildfire workflows. + +This dataset is currently documented as an external or inspection-first +source rather than a public ``load_dataset(...)`` entrypoint. + +Related Coverage +~~~~~~~~~~~~~~~~ + +**Benchmarks:** :doc:`Wildfire Benchmark ` + +Inspection Workflow +------------------- + +Use the documented inspection path below to validate local files before training or analysis. + +.. code-block:: bash + + ogrinfo -so "/home/runyang/ryang/FRAP_Fire_Perimeters/shapefile/California_Fire_Perimeters_(all).shp" "California_Fire_Perimeters_(all)" + +Notes +----- + +- FRAP is especially useful when you want a California-specific historical perimeter reference in addition to national feeds. +- Local copy detected at ``/home/runyang/ryang/FRAP_Fire_Perimeters``. + +Reference +--------- + +- `CAL FIRE FRAP Fire Perimeters `_. diff --git a/docs/source/datasets/geomac_historical.rst b/docs/source/datasets/geomac_historical.rst new file mode 100644 index 00000000..8437c2f8 --- /dev/null +++ b/docs/source/datasets/geomac_historical.rst @@ -0,0 +1,98 @@ +.. This file is generated by scripts/render_dataset_docs.py. Do not edit by hand. + +GeoMAC Historical +================= + +Historical GeoMAC wildfire perimeters preserved as a legacy U.S. perimeter archive for long-horizon evaluation. + +Overview +-------- + +GeoMAC Historical packages legacy wildfire perimeter archives that predate newer interagency operational feeds. + +In PyHazards it acts as a historical archive source for long-range retrospective wildfire evaluation, especially when you need older national perimeter context before newer incident systems became standard. + +At a Glance +----------- + +.. 
list-table:: + :widths: 28 72 + :stub-columns: 1 + + * - Provider + - Legacy GeoMAC / USGS-hosted historical archive + * - Hazard Family + - Wildfire + * - Source Role + - Historical Perimeters + * - Coverage + - United States + * - Geometry + - Archived wildfire perimeter polygons + * - Spatial Resolution + - Event-level perimeter geometries + * - Temporal Resolution + - Event-based archive + * - Update Cadence + - Legacy archive; local copy is static + * - Period of Record + - Local archive includes 2000-2018 plus 2019 packages + * - Formats + - ZIP archives containing GIS perimeter products + * - Inspection CLI + - ``unzip -l "/home/runyang/ryang/GeoMAC_Historical/Historic_Geomac_Perimeters_All_Years_2000_2018/Historic_Geomac_Perimeters_All_Years_2000_2018.zip" | head`` + +Data Characteristics +-------------------- + +- Legacy archive rather than a live operational feed. +- Useful for extending historical perimeter coverage when evaluating older wildfire seasons. +- Typically consumed after extraction into standard GIS formats. +- Best paired with newer systems such as WFIGS for post-2019 workflows. + +Typical Use Cases +~~~~~~~~~~~~~~~~~ + +- Long-horizon historical wildfire perimeter studies. +- Retrospective perimeter benchmarking across older U.S. wildfire seasons. +- Gap-filling historical archives before newer interagency feeds. + +Access +------ + +Use the links below to access the upstream source or its public documentation. + +- `USGS Data Series 612: GeoMAC wildfire perimeters `_ + +PyHazards Usage +--------------- + +Use the local archives as an external inspection-first source when older U.S. wildfire perimeter history is needed. + +This dataset is currently documented as an external or inspection-first +source rather than a public ``load_dataset(...)`` entrypoint. 
+ +Related Coverage +~~~~~~~~~~~~~~~~ + +**Benchmarks:** :doc:`Wildfire Benchmark ` + +Inspection Workflow +------------------- + +Use the documented inspection path below to validate local files before training or analysis. + +.. code-block:: bash + + unzip -l "/home/runyang/ryang/GeoMAC_Historical/Historic_Geomac_Perimeters_All_Years_2000_2018/Historic_Geomac_Perimeters_All_Years_2000_2018.zip" | head + +Notes +----- + +- GeoMAC Historical is a legacy archive and should be treated as a historical reference rather than a live feed. +- Local copy detected at ``/home/runyang/ryang/GeoMAC_Historical``. + +Reference +--------- + +- `USGS Data Series 612: GeoMAC wildfire perimeters `_. diff --git a/docs/source/datasets/goes_geocolor.rst b/docs/source/datasets/goes_geocolor.rst new file mode 100644 index 00000000..8302f579 --- /dev/null +++ b/docs/source/datasets/goes_geocolor.rst @@ -0,0 +1,85 @@ +.. This file is generated by scripts/render_dataset_docs.py. Do not edit by hand. + +GOES GeoColor +============= + +NOAA GOES-East/West GeoColor imagery source used for visual fire-scene context. + +Overview +-------- + +GOES GeoColor imagery combines visible and infrared channels into an easy-to-interpret geostationary imagery product. + +In PyHazards it acts as wildfire scene context imagery for visual verification, event inspection, and qualitative comparison against fire and smoke products. + +At a Glance +----------- + +.. 
list-table:: + :widths: 28 72 + :stub-columns: 1 + + * - Provider + - NOAA GOES / CIRA GeoColor imagery services + * - Hazard Family + - Shared Forcing + * - Source Role + - Satellite Imagery Context + * - Coverage + - GOES-East/West views over the Americas + * - Geometry + - Geostationary imagery time series + * - Spatial Resolution + - ABI imagery resolution on the fixed grid + * - Temporal Resolution + - About every 10 minutes + * - Update Cadence + - Continuous ingest as new imagery becomes available + * - Period of Record + - Local copy spans 2017-2026 with GOES-18 subset on disk + * - Formats + - Image products and derived imagery files + * - Inspection CLI + - ``find /home/runyang/ryang/GOES_GeoColor_CIRA -maxdepth 3 -type f | head`` + +Data Characteristics +-------------------- + +- Visual-context imagery rather than direct fire detections. +- Useful for scene interpretation, plume verification, and rapid event review. +- High temporal refresh over the geostationary domain. +- Best paired with GOES-R FDCF or HMS smoke products. + +Typical Use Cases +~~~~~~~~~~~~~~~~~ + +- Visual wildfire scene context. +- Smoke and plume inspection. +- Manual event triage and QA. + +Access +------ + +- `CIRA Slider `_ + +PyHazards Usage +--------------- + +Use this imagery archive as an inspection-first visual context source. + +Related Coverage +~~~~~~~~~~~~~~~~ + +**Benchmarks:** :doc:`Wildfire Benchmark ` + +Inspection Workflow +------------------- + +.. code-block:: bash + + find /home/runyang/ryang/GOES_GeoColor_CIRA -maxdepth 3 -type f | head + +Notes +----- + +- Local copy detected at ``/home/runyang/ryang/GOES_GeoColor_CIRA``. diff --git a/docs/source/datasets/goesr_fdcf.rst b/docs/source/datasets/goesr_fdcf.rst new file mode 100644 index 00000000..e24ac806 --- /dev/null +++ b/docs/source/datasets/goesr_fdcf.rst @@ -0,0 +1,99 @@ +.. This file is generated by scripts/render_dataset_docs.py. Do not edit by hand. 
+ +GOES-R FDCF +=========== + +GOES-R ABI Fire/Hot Spot Characterization files used for high-frequency active-fire monitoring across the Americas. + +Overview +-------- + +GOES-R FDCF is the ABI Fire/Hot Spot Characterization product from the GOES-R series, providing rapid-refresh geostationary active-fire and hot-spot information. + +In PyHazards it serves as a wildfire-specific geostationary fire-monitoring source that complements FIRMS with much higher refresh frequency over the GOES-East and GOES-West domains. + +At a Glance +----------- + +.. list-table:: + :widths: 28 72 + :stub-columns: 1 + + * - Provider + - NOAA GOES-R Series / ABI + * - Hazard Family + - Wildfire + * - Source Role + - Geostationary Active Fire + * - Coverage + - GOES-East and GOES-West full-disk views over the Americas + * - Geometry + - Geostationary raster NetCDF time series + * - Spatial Resolution + - Product pixels at geostationary ABI resolution (roughly kilometer-scale at nadir) + * - Temporal Resolution + - About every 10 minutes for full-disk scans + * - Update Cadence + - Continuous operational production as new scans arrive + * - Period of Record + - GOES-16 and GOES-18 operational era + * - Formats + - NetCDF + * - Inspection CLI + - ``python -m pyhazards.datasets.goesr.inspection --path /home/runyang/ryang/GOES_FDCF_G16/2024 --max-items 10`` + +Data Characteristics +-------------------- + +- Geostationary fire monitoring with much higher temporal refresh than polar-orbiting active-fire products. +- Product is especially useful for tracking rapidly evolving wildfire activity. +- Domain is regional rather than global, tied to GOES-East and GOES-West views. +- Best used alongside FIRMS, incident records, and perimeter archives. + +Typical Use Cases +~~~~~~~~~~~~~~~~~ + +- Rapid-refresh wildfire activity monitoring. +- Temporal alignment of fire activity with smoke and weather products. +- Cross-checking high-frequency fire dynamics against FIRMS hotspots. 
+ +Access +------ + +Use the links below to access the upstream source or its public documentation. + +- `GOES-R Fire/Hot Spot Characterization product `_ +- `GOES-R product page at NOAA STAR `_ + +PyHazards Usage +--------------- + +Use the local GOES-East and GOES-West NetCDF archive as an external inspection-first source for high-frequency wildfire monitoring workflows. + +This dataset is currently documented as an external or inspection-first +source rather than a public ``load_dataset(...)`` entrypoint. + +Related Coverage +~~~~~~~~~~~~~~~~ + +**Benchmarks:** :doc:`Wildfire Benchmark ` + +Inspection Workflow +------------------- + +Use the documented inspection path below to validate local files before training or analysis. + +.. code-block:: bash + + python -m pyhazards.datasets.goesr.inspection --path /home/runyang/ryang/GOES_FDCF_G16/2024 --max-items 10 + +Notes +----- + +- GOES-R FDCF complements FIRMS by trading lower spatial precision for much higher temporal refresh. +- Local copies detected at ``/home/runyang/ryang/GOES_FDCF_G16`` and ``/home/runyang/ryang/GOES_FDCF_G18``. + +Reference +--------- + +- `GOES-R Fire/Hot Spot Characterization product `_. diff --git a/docs/source/datasets/hms_smoke.rst b/docs/source/datasets/hms_smoke.rst new file mode 100644 index 00000000..6f70a5fb --- /dev/null +++ b/docs/source/datasets/hms_smoke.rst @@ -0,0 +1,99 @@ +.. This file is generated by scripts/render_dataset_docs.py. Do not edit by hand. + +HMS Smoke +========= + +NOAA analyst-drawn smoke plume polygons used for smoke tracking, verification, and wildfire smoke exposure analysis. + +Overview +-------- + +HMS Smoke is part of NOAA's Hazard Mapping System, where analysts blend multiple satellite streams to delineate visible smoke plume extent. + +In PyHazards it serves as a smoke-impact companion dataset for wildfire analysis, useful for plume verification, smoke transport evaluation, and exposure-aware wildfire workflows. + +At a Glance +----------- + +.. 
list-table:: + :widths: 28 72 + :stub-columns: 1 + + * - Provider + - NOAA NESDIS / Hazard Mapping System (HMS) + * - Hazard Family + - Wildfire + * - Source Role + - Smoke Plumes + * - Coverage + - North America, Hawaii, and the Caribbean + * - Geometry + - Analyst-drawn smoke polygons + * - Spatial Resolution + - Vector plume extents with analyst-interpreted boundaries + * - Temporal Resolution + - Sub-daily plume updates + * - Update Cadence + - Near-real-time analyst updates during active smoke events + * - Period of Record + - Ongoing operational archive with historical yearly packages + * - Formats + - Shapefile and zipped archive packages + * - Inspection CLI + - ``ogrinfo -so "/home/runyang/ryang/HMS_Smoke/2024/shapefile/hms_smoke2024.shp" hms_smoke2024`` + +Data Characteristics +-------------------- + +- Polygon smoke extents rather than fire detections or perimeters. +- Interpreted product derived from multiple satellite views and analyst QA. +- Useful for smoke verification and impact mapping, not just fire ignition or spread. +- Complements FIRMS, GOES fire products, and incident perimeter archives. + +Typical Use Cases +~~~~~~~~~~~~~~~~~ + +- Smoke plume verification and event analysis. +- Exposure-aware wildfire impact studies. +- Cross-checking smoke extent against active-fire and perimeter products. + +Access +------ + +Use the links below to access the upstream source or its public documentation. + +- `NOAA HMS Fire and Smoke Analysis `_ +- `NASA ARSET overview mentioning HMS smoke product `_ + +PyHazards Usage +--------------- + +Use the local shapefile archive as an external inspection-first source for smoke-plume-aware wildfire workflows. + +This dataset is currently documented as an external or inspection-first +source rather than a public ``load_dataset(...)`` entrypoint. 
+ +Related Coverage +~~~~~~~~~~~~~~~~ + +**Benchmarks:** :doc:`Wildfire Benchmark ` + +Inspection Workflow +------------------- + +Use the documented inspection path below to validate local files before training or analysis. + +.. code-block:: bash + + ogrinfo -so "/home/runyang/ryang/HMS_Smoke/2024/shapefile/hms_smoke2024.shp" hms_smoke2024 + +Notes +----- + +- HMS Smoke is especially useful when you want smoke impact context, not only fire occurrence. +- Local copy detected at ``/home/runyang/ryang/HMS_Smoke``. + +Reference +--------- + +- `NOAA HMS Fire and Smoke Analysis `_. diff --git a/docs/source/datasets/hpwren_weather.rst b/docs/source/datasets/hpwren_weather.rst new file mode 100644 index 00000000..c475716c --- /dev/null +++ b/docs/source/datasets/hpwren_weather.rst @@ -0,0 +1,85 @@ +.. This file is generated by scripts/render_dataset_docs.py. Do not edit by hand. + +HPWREN Weather +============== + +Public HPWREN station feeds for metadata, realtime weather observations, and recent local weather context. + +Overview +-------- + +HPWREN provides wildfire-relevant station observations and station metadata from Southern California mountain and foothill environments. + +In PyHazards it serves as a local weather-station context source for wildfire operations, station-based validation, and regional feature engineering. + +At a Glance +----------- + +.. 
list-table:: + :widths: 28 72 + :stub-columns: 1 + + * - Provider + - HPWREN / University of California San Diego + * - Hazard Family + - Shared Forcing + * - Source Role + - Weather Stations + * - Coverage + - HPWREN station network footprint + * - Geometry + - Station points with tabular observations + * - Spatial Resolution + - Station-level observations + * - Temporal Resolution + - Minutes to hourly depending on station/feed + * - Update Cadence + - Real-time operational updates plus archived monthly summaries + * - Period of Record + - Local copy spans 2000-2026 + * - Formats + - Text, CSV-style tables, and station metadata files + * - Inspection CLI + - ``find /home/runyang/ryang/HPWREN_Weather -maxdepth 2 -type f | head`` + +Data Characteristics +-------------------- + +- Station-based observations rather than gridded forecasts. +- Useful for local fire-weather context and sanity checks against model forcing. +- Includes metadata, real-time feeds, and historical monthly directories. +- Best used together with gridded forecast products such as HRRR or NDFD. + +Typical Use Cases +~~~~~~~~~~~~~~~~~ + +- Local fire-weather validation. +- Station-context feature engineering. +- Regional monitoring dashboards for wildfire operations. + +Access +------ + +- `HPWREN `_ + +PyHazards Usage +--------------- + +Use this local station archive as an inspection-first source for wildfire weather context and QA workflows. + +Related Coverage +~~~~~~~~~~~~~~~~ + +**Benchmarks:** :doc:`Wildfire Benchmark ` + +Inspection Workflow +------------------- + +.. code-block:: bash + + find /home/runyang/ryang/HPWREN_Weather -maxdepth 2 -type f | head + +Notes +----- + +- Local copy detected at ``/home/runyang/ryang/HPWREN_Weather``. diff --git a/docs/source/datasets/hrrr.rst b/docs/source/datasets/hrrr.rst new file mode 100644 index 00000000..c77866ff --- /dev/null +++ b/docs/source/datasets/hrrr.rst @@ -0,0 +1,86 @@ +.. This file is generated by scripts/render_dataset_docs.py. 
Do not edit by hand. + +HRRR +==== + +NOAA HRRR forecast layers used for dynamic short-range wildfire weather features. + +Overview +-------- + +HRRR is NOAA's high-resolution rapid-refresh forecast system, updated hourly and designed for short-range weather prediction at convection-allowing resolution. + +In PyHazards it serves as a short-range weather forecast backbone for wildfire prediction and operational forecast feature extraction. + +At a Glance +----------- + +.. list-table:: + :widths: 28 72 + :stub-columns: 1 + + * - Provider + - NOAA HRRR + * - Hazard Family + - Shared Forcing + * - Source Role + - Weather Forecast + * - Coverage + - CONUS-focused forecast domain + * - Geometry + - Gridded forecast fields + * - Spatial Resolution + - About 3 km + * - Temporal Resolution + - Hourly model cycles with short forecast lead times + * - Update Cadence + - Hourly + * - Period of Record + - Local copy spans 2014-2026 with 2024 archive on disk + * - Formats + - GRIB2 and derivative archives + * - Inspection CLI + - ``find /home/runyang/ryang/HRRR/2024 -maxdepth 3 -type f | head`` + +Data Characteristics +-------------------- + +- Short-range numerical weather prediction rather than reanalysis. +- High spatial and temporal refresh for dynamic fire-weather context. +- Useful for forecast-aware wildfire features such as wind, humidity, and precipitation. +- Often paired with NDFD for operational forecast context. + +Typical Use Cases +~~~~~~~~~~~~~~~~~ + +- Short-range wildfire weather forecasting features. +- Forecast forcing for next-day wildfire benchmark experiments. +- Operational fire-weather context analysis. + +Access +------ + +- `HRRR official page <https://rapidrefresh.noaa.gov/hrrr/>`_ +- `NOAA HRRR open-data listing <https://registry.opendata.aws/noaa-hrrr-pds/>`_ + +PyHazards Usage +--------------- + +Use the local archive as an inspection-first short-range forecast source. + +Related Coverage +~~~~~~~~~~~~~~~~ + +**Benchmarks:** :doc:`Wildfire Benchmark ` + +Inspection Workflow +------------------- + +..
code-block:: bash + + find /home/runyang/ryang/HRRR/2024 -maxdepth 3 -type f | head + +Notes +----- + +- Local copy detected at ``/home/runyang/ryang/HRRR``. diff --git a/docs/source/datasets/landscan_population.rst b/docs/source/datasets/landscan_population.rst new file mode 100644 index 00000000..07c58548 --- /dev/null +++ b/docs/source/datasets/landscan_population.rst @@ -0,0 +1,86 @@ +.. This file is generated by scripts/render_dataset_docs.py. Do not edit by hand. + +LandScan Population +=================== + +ORNL LandScan population package used for population-at-risk and exposure context in wildfire workflows. + +Overview +-------- + +LandScan is ORNL's global gridded population product designed for estimating ambient population distribution. + +In PyHazards it serves as a wildfire exposure and population-at-risk context layer for evaluation, risk modeling, and human-impact analysis. + +At a Glance +----------- + +.. list-table:: + :widths: 28 72 + :stub-columns: 1 + + * - Provider + - Oak Ridge National Laboratory (ORNL) + * - Hazard Family + - Wildfire + * - Source Role + - Population Exposure Context + * - Coverage + - Global + * - Geometry + - Gridded population rasters + * - Spatial Resolution + - About 30 arc-seconds globally + * - Temporal Resolution + - Annual releases + * - Update Cadence + - Release-based / annual + * - Period of Record + - Local copy includes LandScan Global 2024 + * - Formats + - GeoTIFF and extracted raster packages + * - Inspection CLI + - ``find /home/runyang/ryang/LandScan_Global_2024 -maxdepth 3 -type f | head`` + +Data Characteristics +-------------------- + +- Population exposure raster rather than wildfire observations. +- Useful for population-at-risk analysis and human-impact context. +- Global rather than wildfire-specific, but often highly relevant in hazard studies. +- Best paired with incident or perimeter data. + +Typical Use Cases +~~~~~~~~~~~~~~~~~ + +- Population-at-risk context for wildfire events. 
+- Exposure-aware risk mapping. +- Human-impact summaries for benchmark analysis. + +Access +------ + +- `ORNL LandScan Viewer `_ +- `LandScan Global 2024 dataset entry `_ + +PyHazards Usage +--------------- + +Use this local raster package as an inspection-first population context source. + +Related Coverage +~~~~~~~~~~~~~~~~ + +**Benchmarks:** :doc:`Wildfire Benchmark ` + +Inspection Workflow +------------------- + +.. code-block:: bash + + find /home/runyang/ryang/LandScan_Global_2024 -maxdepth 3 -type f | head + +Notes +----- + +- Local copy detected at ``/home/runyang/ryang/LandScan_Global_2024``. diff --git a/docs/source/datasets/nasa_gibs.rst b/docs/source/datasets/nasa_gibs.rst new file mode 100644 index 00000000..ab5be0f2 --- /dev/null +++ b/docs/source/datasets/nasa_gibs.rst @@ -0,0 +1,85 @@ +.. This file is generated by scripts/render_dataset_docs.py. Do not edit by hand. + +NASA GIBS +========= + +NASA EOSDIS global imagery via WMS/WMTS used for daily true-color satellite imagery. + +Overview +-------- + +NASA GIBS provides easy-to-browse global imagery layers from EOSDIS through map tile and imagery services. + +In PyHazards it acts as daily wildfire scene imagery context and a lightweight remote-sensing browse layer for event inspection. + +At a Glance +----------- + +.. 
list-table:: + :widths: 28 72 + :stub-columns: 1 + + * - Provider + - NASA EOSDIS / GIBS + * - Hazard Family + - Shared Forcing + * - Source Role + - Satellite Imagery Context + * - Coverage + - Global + * - Geometry + - Tiled imagery and browse layers + * - Spatial Resolution + - Product-dependent imagery resolutions + * - Temporal Resolution + - Daily imagery products + * - Update Cadence + - Daily + * - Period of Record + - Local copy spans 2000-2026 with 2024 imagery subset on disk + * - Formats + - WMTS/WMS layers and downloaded imagery tiles + * - Inspection CLI + - ``find /home/runyang/ryang/NASA_GIBS_2024 -maxdepth 3 -type f | head`` + +Data Characteristics +-------------------- + +- Browse-oriented imagery service rather than analysis-ready tensors. +- Useful for qualitative inspection and event context. +- Global daily coverage across multiple EOSDIS imagery layers. +- Best paired with analytical fire or smoke products when building workflows. + +Typical Use Cases +~~~~~~~~~~~~~~~~~ + +- Daily true-color wildfire imagery context. +- Manual inspection of fire events and plume signatures. +- Remote-sensing browse support for benchmark QA. + +Access +------ + +- `NASA GIBS overview `_ + +PyHazards Usage +--------------- + +Use this imagery archive as an inspection-first context source rather than a registry dataset. + +Related Coverage +~~~~~~~~~~~~~~~~ + +**Benchmarks:** :doc:`Wildfire Benchmark ` + +Inspection Workflow +------------------- + +.. code-block:: bash + + find /home/runyang/ryang/NASA_GIBS_2024 -maxdepth 3 -type f | head + +Notes +----- + +- Local copy detected at ``/home/runyang/ryang/NASA_GIBS_2024``. diff --git a/docs/source/datasets/ndfd.rst b/docs/source/datasets/ndfd.rst new file mode 100644 index 00000000..93053c6d --- /dev/null +++ b/docs/source/datasets/ndfd.rst @@ -0,0 +1,85 @@ +.. This file is generated by scripts/render_dataset_docs.py. Do not edit by hand. 
+ +NDFD +==== + +NOAA NDFD grids and warning bulletins used for wildfire forecast context, watches, warnings, and advisories. + +Overview +-------- + +The National Digital Forecast Database packages official National Weather Service forecast grids and public hazard products. + +In PyHazards it provides operational forecast layers and warning context for wildfire-weather workflows, including critical fire weather and watches/warnings fields. + +At a Glance +----------- + +.. list-table:: + :widths: 28 72 + :stub-columns: 1 + + * - Provider + - NOAA National Weather Service / NDFD + * - Hazard Family + - Shared Forcing + * - Source Role + - Weather Forecast and Watches/Warnings + * - Coverage + - United States public forecast grids + * - Geometry + - Gridded forecast layers and bulletin products + * - Spatial Resolution + - Forecast grid products with variable regional resolution + * - Temporal Resolution + - Hourly to daily depending on field + * - Update Cadence + - Issue-based for hazards and routine forecast refresh for grids + * - Period of Record + - Local copy spans 2000-2026 + * - Formats + - GRIB2, text, and derived grids + * - Inspection CLI + - ``find /home/runyang/ryang/NDFD -maxdepth 2 -type d | head`` + +Data Characteristics +-------------------- + +- Official forecast grids rather than model reanalysis. +- Includes fire-weather relevant variables and warning/advisory products. +- Useful for operational context and downstream feature extraction. +- Complements HRRR when both official forecast products and model guidance are needed. + +Typical Use Cases +~~~~~~~~~~~~~~~~~ + +- Fire-weather forecast feature engineering. +- Watches, warnings, and advisories context. +- Operational wildfire decision-support pipelines. + +Access +------ + +- `NDFD / digital.weather.gov `_ + +PyHazards Usage +--------------- + +Use the local NDFD archive as an inspection-first operational forecast source. 
+ +Related Coverage +~~~~~~~~~~~~~~~~ + +**Benchmarks:** :doc:`Wildfire Benchmark ` + +Inspection Workflow +------------------- + +.. code-block:: bash + + find /home/runyang/ryang/NDFD -maxdepth 2 -type d | head + +Notes +----- + +- Local copy detected at ``/home/runyang/ryang/NDFD``. diff --git a/docs/source/datasets/nohrsc_snodas.rst b/docs/source/datasets/nohrsc_snodas.rst new file mode 100644 index 00000000..9b7f4ecd --- /dev/null +++ b/docs/source/datasets/nohrsc_snodas.rst @@ -0,0 +1,85 @@ +.. This file is generated by scripts/render_dataset_docs.py. Do not edit by hand. + +NOHRSC SNODAS +============= + +NOAA snow-analysis / SNODAS daily archives used as snow-condition context for wildfire-weather studies. + +Overview +-------- + +SNODAS is the Snow Data Assimilation System distributed by NOAA NOHRSC, providing daily gridded snow-condition products. + +In PyHazards it supplies snow-state context that can matter for seasonal fuel curing, hydrologic carryover, and mountain wildfire-weather workflows. + +At a Glance +----------- + +.. list-table:: + :widths: 28 72 + :stub-columns: 1 + + * - Provider + - NOAA NOHRSC / SNODAS + * - Hazard Family + - Shared Forcing + * - Source Role + - Snow Analysis + * - Coverage + - Continental United States + * - Geometry + - Gridded raster fields + * - Spatial Resolution + - About 1 km + * - Temporal Resolution + - Daily + * - Update Cadence + - Daily + * - Period of Record + - Local copy spans 2003-2026 with 2024 archive on disk + * - Formats + - Gridded archives and derived masks + * - Inspection CLI + - ``find /home/runyang/ryang/NOHRSC_SNODAS_masked_2024 -maxdepth 2 -type d | head`` + +Data Characteristics +-------------------- + +- Daily snow-condition product rather than direct wildfire observations. +- Useful as seasonal context for fuel and landscape state. +- Best integrated with weather forcing and topographic context. +- Particularly relevant for mountain and snow-affected regions. 
+ +Typical Use Cases +~~~~~~~~~~~~~~~~~ + +- Snow-condition context for wildfire feature engineering. +- Seasonal carryover analysis. +- Landscape-state covariates in western U.S. wildfire workflows. + +Access +------ + +- `NOHRSC archived data and SNODAS description `_ + +PyHazards Usage +--------------- + +Use the local daily archive as an inspection-first forcing source when snow state matters for wildfire context. + +Related Coverage +~~~~~~~~~~~~~~~~ + +**Benchmarks:** :doc:`Wildfire Benchmark ` + +Inspection Workflow +------------------- + +.. code-block:: bash + + find /home/runyang/ryang/NOHRSC_SNODAS_masked_2024 -maxdepth 2 -type d | head + +Notes +----- + +- Local copy detected at ``/home/runyang/ryang/NOHRSC_SNODAS_masked_2024``. diff --git a/docs/source/datasets/spot_forecast.rst b/docs/source/datasets/spot_forecast.rst new file mode 100644 index 00000000..7e5e8030 --- /dev/null +++ b/docs/source/datasets/spot_forecast.rst @@ -0,0 +1,87 @@ +.. This file is generated by scripts/render_dataset_docs.py. Do not edit by hand. + +Spot Forecast +============= + +NOAA NWS spot forecast products used for incident-specific forecast guidance and operational fire-weather context. + +Overview +-------- + +Spot Forecast products are incident-focused weather forecast products prepared by the National Weather Service for wildfire and emergency operations. + +In PyHazards they provide operational fire-weather context for incident timelines, analyst review, and retrospective comparison with model-based forecast sources. + +At a Glance +----------- + +.. 
list-table:: + :widths: 28 72 + :stub-columns: 1 + + * - Provider + - NOAA National Weather Service + * - Hazard Family + - Shared Forcing + * - Source Role + - Incident Forecast Guidance + * - Coverage + - Incident-specific forecast products + * - Geometry + - Text and bulletin-style forecast products + * - Spatial Resolution + - Incident/request level + * - Temporal Resolution + - Issue-based + * - Update Cadence + - Generated when requested for active incidents + * - Period of Record + - Local copy spans 2000-2026 + * - Formats + - Text products and support lists + * - Inspection CLI + - ``find /home/runyang/ryang/Spot_Forecast_Current -maxdepth 2 -type f | head`` + +Data Characteristics +-------------------- + +- Operational forecast guidance rather than retrospective climate data. +- Highly incident-specific and request-driven. +- Useful for contextualizing decisions and incident weather expectations. +- Best interpreted alongside broader forecast grids and observations. + +Typical Use Cases +~~~~~~~~~~~~~~~~~ + +- Incident-level fire-weather context. +- Retrospective comparison of model forecast versus operational forecast guidance. +- Operational timeline reconstruction. + +Access +------ + +- `NWS Spot Forecast page `_ +- `NWS fire weather resources `_ + +PyHazards Usage +--------------- + +Use this product archive as an inspection-first operational context source rather than a direct ``load_dataset(...)`` path. + +Related Coverage +~~~~~~~~~~~~~~~~ + +**Benchmarks:** :doc:`Wildfire Benchmark ` + +Inspection Workflow +------------------- + +.. code-block:: bash + + find /home/runyang/ryang/Spot_Forecast_Current -maxdepth 2 -type f | head + +Notes +----- + +- Spot Forecast is best treated as operations context, not a uniform gridded forcing dataset. +- Local copy detected at ``/home/runyang/ryang/Spot_Forecast_Current``. 
diff --git a/docs/source/datasets/synoptic_weather.rst b/docs/source/datasets/synoptic_weather.rst new file mode 100644 index 00000000..a9e6b9d4 --- /dev/null +++ b/docs/source/datasets/synoptic_weather.rst @@ -0,0 +1,85 @@ +.. This file is generated by scripts/render_dataset_docs.py. Do not edit by hand. + +Synoptic Weather +================ + +Synoptic station metadata and snapshots used for weather-station context in wildfire workflows. + +Overview +-------- + +Synoptic aggregates real-time and historical weather station observations and metadata through a common API and bulk-access workflow. + +In PyHazards it serves as station-based weather context for wildfire operations, event review, and local observation cross-checks. + +At a Glance +----------- + +.. list-table:: + :widths: 28 72 + :stub-columns: 1 + + * - Provider + - Synoptic Data + * - Hazard Family + - Shared Forcing + * - Source Role + - Weather Stations + * - Coverage + - Multi-network station coverage where access is available + * - Geometry + - Station points with tabular observations and metadata + * - Spatial Resolution + - Station-level observations + * - Temporal Resolution + - Minutes to hourly depending on station/network + * - Update Cadence + - Near-real-time for current feeds; historical access depends on plan tier + * - Period of Record + - Local copy spans 2000-2026 in current snapshots + * - Formats + - JSON/CSV-style outputs and metadata tables + * - Inspection CLI + - ``find /home/runyang/ryang/Synoptic_Weather_Current -maxdepth 2 -type f | head`` + +Data Characteristics +-------------------- + +- Station-based observations rather than gridded forecasts. +- Useful for local weather context and network metadata. +- Historical completeness depends on the access tier available at download time. +- Complements HPWREN and forecast grids. + +Typical Use Cases +~~~~~~~~~~~~~~~~~ + +- Weather-station context for wildfire workflows. +- Cross-checking forecast grids against local observations. 
+- Metadata inspection for station-network selection. + +Access +------ + +- `Synoptic Weather API `_ + +PyHazards Usage +--------------- + +Use the local snapshots as an inspection-first station-context source. + +Related Coverage +~~~~~~~~~~~~~~~~ + +**Benchmarks:** :doc:`Wildfire Benchmark ` + +Inspection Workflow +------------------- + +.. code-block:: bash + + find /home/runyang/ryang/Synoptic_Weather_Current -maxdepth 2 -type f | head + +Notes +----- + +- Local copies detected at ``/home/runyang/ryang/Synoptic_Weather_Current`` and related Synoptic directories. diff --git a/docs/source/datasets/wrc_housing_density.rst b/docs/source/datasets/wrc_housing_density.rst new file mode 100644 index 00000000..cc500292 --- /dev/null +++ b/docs/source/datasets/wrc_housing_density.rst @@ -0,0 +1,85 @@ +.. This file is generated by scripts/render_dataset_docs.py. Do not edit by hand. + +WRC Housing Density +=================== + +USDA Forest Service housing-density raster used for exposure and population-at-risk context. + +Overview +-------- + +This housing-density raster is part of the Wildfire Risk to Communities ecosystem and encodes where homes or housing units are concentrated in fire-prone landscapes. + +In PyHazards it serves as a static exposure covariate for wildfire risk, WUI context, and population-at-risk analysis. + +At a Glance +----------- + +.. 
list-table:: + :widths: 28 72 + :stub-columns: 1 + + * - Provider + - USDA Forest Service / Wildfire Risk to Communities + * - Hazard Family + - Wildfire + * - Source Role + - Exposure and Community Context + * - Coverage + - United States + * - Geometry + - Raster exposure layers + * - Spatial Resolution + - About 30 m + * - Temporal Resolution + - Static or release-based + * - Update Cadence + - Release-based + * - Period of Record + - Local copy corresponds to 2018-era housing density package + * - Formats + - Raster packages and extracted tiles + * - Inspection CLI + - ``find /home/runyang/ryang/WRC_Housing_Density -maxdepth 3 -type f | head`` + +Data Characteristics +-------------------- + +- Exposure-focused raster rather than fire detections. +- Useful for WUI and human-exposure context in wildfire modeling. +- Complements fuels, perimeters, and population layers. +- Supports risk interpretation rather than direct fire labeling. + +Typical Use Cases +~~~~~~~~~~~~~~~~~ + +- WUI and housing-exposure covariates. +- Community wildfire risk context. +- Population-at-risk and exposure analysis. + +Access +------ + +- `Wildfire Risk to Communities datasets `_ + +PyHazards Usage +--------------- + +Use this local raster package as an inspection-first wildfire exposure layer. + +Related Coverage +~~~~~~~~~~~~~~~~ + +**Benchmarks:** :doc:`Wildfire Benchmark ` + +Inspection Workflow +------------------- + +.. code-block:: bash + + find /home/runyang/ryang/WRC_Housing_Density -maxdepth 3 -type f | head + +Notes +----- + +- Local copy detected at ``/home/runyang/ryang/WRC_Housing_Density``. diff --git a/docs/source/modules/models_asufm.rst b/docs/source/modules/models_asufm.rst index 388e3e7f..0e136dd5 100644 --- a/docs/source/modules/models_asufm.rst +++ b/docs/source/modules/models_asufm.rst @@ -1,94 +1,44 @@ -.. This file is generated by scripts/render_model_docs.py. Do not edit by hand. 
- ASUFM ===== -Overview --------- - -``asufm`` is a compact temporal convolution baseline for next-window wildfire activity prediction. - -At a Glance ------------ - -.. grid:: 1 2 4 4 - :gutter: 2 - :class-container: catalog-grid - - .. grid-item-card:: Hazard Family - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - Wildfire - - .. container:: catalog-stat-note - - Public catalog grouping used for this model. - - .. grid-item-card:: Maturity - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - Implemented - - .. container:: catalog-stat-note - - Catalog maturity label used on the index page. - - .. grid-item-card:: Tasks - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - 1 - - .. container:: catalog-stat-note - - Forecasting - - .. grid-item-card:: Benchmark Family - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - :doc:`Wildfire Benchmark ` - - .. container:: catalog-stat-note - - Primary benchmark-family link used for compatible evaluation coverage. - - Description ----------- -``asufm`` is a compact temporal convolution baseline for next-window wildfire activity prediction. +``asufm`` is a self-contained PyHazards port of the ASUFM wildfire model family: +an Attention Swin U-Net with focal modulation for wildfire spread prediction. -PyHazards exposes it through the shared wildfire benchmark and config workflow. 
+This module follows the official ASUFM configuration pattern with: -Benchmark Compatibility ------------------------ +- ``image_size=64`` +- ``patch_size=4`` +- ``in_channels=6`` +- ``embed_dim=96`` +- ``depths=(2, 2, 2, 2)`` +- ``num_heads=(3, 6, 12, 24)`` +- focal modulation in the encoder +- attention-gated skip connections in the decoder -**Primary benchmark family:** :doc:`Wildfire Benchmark ` +Paper / source +-------------- -External References -------------------- +- `Wildfire Spread Prediction in North America Using Satellite Imagery and Vision Transformer `_ +- `Official repository `_ -**Paper:** `Wildfire Spread Prediction in North America Using Satellite Imagery and Vision Transformer `_ | **Repo:** `Repository `__ +Paper parity note +----------------- -Registry Name -------------- +This PyHazards implementation preserves the main architectural ideas from the +official repository while staying dependency-free inside the main library. +It intentionally replaces the original ``timm``/``einops``-based components +with a native PyTorch implementation of: -Primary entrypoint: ``asufm`` +- patch embedding +- hierarchical Swin-style window attention +- focal modulation in encoder blocks +- U-Net-style decoder with spatially gated skip connections -Supported Tasks ---------------- - -- Forecasting - -Programmatic Use ----------------- +Example of how to use it +------------------------ .. code-block:: python @@ -97,15 +47,12 @@ Programmatic Use model = build_model( name="asufm", - task="forecasting", - input_dim=7, - output_dim=5, - lookback=12, + task="segmentation", + image_size=64, + in_channels=6, + out_dim=1, ) - preds = model(torch.randn(2, 12, 7)) - print(preds.shape) - -Notes ------ -- The smoke path uses weekly wildfire count windows with seasonal time features. 
+ x = torch.randn(2, 6, 64, 64) + logits = model(x) + print(logits.shape) diff --git a/docs/source/modules/models_attention_unet.rst b/docs/source/modules/models_attention_unet.rst new file mode 100644 index 00000000..3fdcf7f1 --- /dev/null +++ b/docs/source/modules/models_attention_unet.rst @@ -0,0 +1,42 @@ +Attention U-Net +=============== + +Description +----------- + +``attention_unet`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.attention_unet_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +.. code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="attention_unet", + task="segmentation", + ) + + if "segmentation" == "classification": + x = torch.randn(4, 16) + else: + x = torch.randn(2, 1, 32, 32) + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_convgru_trajgru.rst b/docs/source/modules/models_convgru_trajgru.rst new file mode 100644 index 00000000..5a3b9a4d --- /dev/null +++ b/docs/source/modules/models_convgru_trajgru.rst @@ -0,0 +1,42 @@ +ConvGRU / TrajGRU +================= + +Description +----------- + +``convgru_trajgru`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. 
+ +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.convgru_trajgru_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +.. code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="convgru_trajgru", + task="segmentation", + ) + + if "segmentation" == "classification": + x = torch.randn(4, 16) + else: + x = torch.randn(2, 1, 32, 32) + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_convlstm.rst b/docs/source/modules/models_convlstm.rst new file mode 100644 index 00000000..ff612614 --- /dev/null +++ b/docs/source/modules/models_convlstm.rst @@ -0,0 +1,42 @@ +ConvLSTM +======== + +Description +----------- + +``convlstm`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.convlstm_track_o``. 
+ +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +.. code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="convlstm", + task="segmentation", + ) + + if "segmentation" == "classification": + x = torch.randn(4, 16) + else: + x = torch.randn(2, 1, 32, 32) + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_deep_ensemble.rst b/docs/source/modules/models_deep_ensemble.rst new file mode 100644 index 00000000..b0a8261e --- /dev/null +++ b/docs/source/modules/models_deep_ensemble.rst @@ -0,0 +1,42 @@ +Deep Ensemble +============= + +Description +----------- + +``deep_ensemble`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.deep_ensemble_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + task = "segmentation" + model = build_model( + name="deep_ensemble", + task=task, + ) + + if task == "classification": + x = torch.randn(4, 16) + else: + x = torch.randn(2, 1, 32, 32) + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_deeplabv3p.rst b/docs/source/modules/models_deeplabv3p.rst new file mode 100644 index 00000000..fb4a9503 --- /dev/null +++ b/docs/source/modules/models_deeplabv3p.rst @@ -0,0 +1,42 @@ +DeepLabv3+ +========== + +Description +----------- + +``deeplabv3p`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.deeplabv3p_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +..
code-block:: python + + import torch + from pyhazards.models import build_model + + task = "segmentation" + model = build_model( + name="deeplabv3p", + task=task, + ) + + if task == "classification": + x = torch.randn(4, 16) + else: + x = torch.randn(2, 1, 32, 32) + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_earthfarseer.rst b/docs/source/modules/models_earthfarseer.rst new file mode 100644 index 00000000..85cc2822 --- /dev/null +++ b/docs/source/modules/models_earthfarseer.rst @@ -0,0 +1,42 @@ +EarthFarseer +============ + +Description +----------- + +``earthfarseer`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.earthfarseer_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +..
code-block:: python + + import torch + from pyhazards.models import build_model + + task = "segmentation" + model = build_model( + name="earthfarseer", + task=task, + ) + + if task == "classification": + x = torch.randn(4, 16) + else: + x = torch.randn(2, 1, 32, 32) + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_earthformer.rst b/docs/source/modules/models_earthformer.rst new file mode 100644 index 00000000..d050fdcc --- /dev/null +++ b/docs/source/modules/models_earthformer.rst @@ -0,0 +1,42 @@ +Earthformer +=========== + +Description +----------- + +``earthformer`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.earthformer_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +..
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="earthformer", + task="segmentation", + ) + + if "segmentation" == "classification": + x = torch.randn(4, 16) + else: + x = torch.randn(2, 1, 32, 32) + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_firecastnet.rst b/docs/source/modules/models_firecastnet.rst index 26970365..381d2bcd 100644 --- a/docs/source/modules/models_firecastnet.rst +++ b/docs/source/modules/models_firecastnet.rst @@ -1,107 +1,47 @@ -.. This file is generated by scripts/render_model_docs.py. Do not edit by hand. - FireCastNet =========== -Overview --------- - -``firecastnet`` is a raster wildfire spread baseline that uses a shallow encoder-decoder architecture. - -At a Glance ------------ - -.. grid:: 1 2 4 4 - :gutter: 2 - :class-container: catalog-grid - - .. grid-item-card:: Hazard Family - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - Wildfire - - .. container:: catalog-stat-note - - Public catalog grouping used for this model. - - .. grid-item-card:: Maturity - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - Implemented - - .. container:: catalog-stat-note - - Catalog maturity label used on the index page. - - .. grid-item-card:: Tasks - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - 1 - - .. container:: catalog-stat-note - - Spread - - .. grid-item-card:: Benchmark Family - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - :doc:`Wildfire Benchmark ` - - .. container:: catalog-stat-note - - Primary benchmark-family link used for compatible evaluation coverage. - - Description ----------- -``firecastnet`` is a raster wildfire spread baseline that uses a shallow encoder-decoder architecture. - -The PyHazards implementation is optimized for the shared smoke benchmark rather than the full upstream training stack. 
- -Benchmark Compatibility ------------------------ +``firecastnet`` is a lightweight PyHazards port of the FireCastNet model family. -**Primary benchmark family:** :doc:`Wildfire Benchmark ` +This module keeps the main benchmark-relevant ideas needed for integration: -**Mapped benchmark ecosystems:** :doc:`WildfireSpreadTS ` +- compact wildfire-risk raster encoder +- dense decoding head +- forecasting-oriented wildfire output map -External References -------------------- +Paper / source +-------------- -**Paper:** `FireCastNet: Earth-as-a-Graph for Seasonal Fire Prediction `_ | **Repo:** `Repository `__ +- `FireCastNet paper `_ -Registry Name -------------- +Paper parity note +----------------- -Primary entrypoint: ``firecastnet`` +This PyHazards implementation is intentionally **not** a full reproduction of +the original FireCastNet seasonal graph pipeline. Instead, it is a clean +benchmark-facing neural port that preserves the forecasting-oriented wildfire +modeling role needed for PyHazards integration. -Supported Tasks ---------------- +It does not claim architecture or preprocessing parity with the original release. -- Spread - -Programmatic Use ----------------- +Example of how to use it +------------------------ .. code-block:: python import torch from pyhazards.models import build_model - model = build_model(name="firecastnet", task="segmentation", in_channels=12) - logits = model(torch.randn(2, 12, 16, 16)) - print(logits.shape) - -Notes ------ + model = build_model( + name="firecastnet", + task="segmentation", + in_channels=12, + out_channels=1, + ) -- The smoke configuration uses the single-frame wildfire spread raster fixture. 
+ x = torch.randn(2, 12, 32, 32) + logits = model(x) + print(logits.shape) diff --git a/docs/source/modules/models_firemm_ir.rst b/docs/source/modules/models_firemm_ir.rst new file mode 100644 index 00000000..1370c986 --- /dev/null +++ b/docs/source/modules/models_firemm_ir.rst @@ -0,0 +1,56 @@ +FireMM-IR +========= + +Description +----------- + +``firemm_ir`` is a benchmark-facing PyHazards port inspired by the FireMM-IR +multi-modal large language model for remote-sensing forest fire monitoring. + +This module preserves the main ideas emphasized by the paper: + +- dual-modality optical + infrared fusion +- class-aware memory +- instruction-conditioned segmentation reasoning +- dense wildfire-scene decoding + +Paper / source +-------------- + +- `FireMM-IR: An Infrared-Enhanced Multi-Modal Large Language Model for Comprehensive Scene Understanding in Remote Sensing Forest Fire Monitoring `_ +- `PubMed entry `_ + +Paper parity note +----------------- + +This PyHazards implementation is intentionally **not** the original full MLLM +stack with text generation, external instruction tuning, and dataset-specific +serving pipeline. Instead, it is a benchmark-friendly neural port that +preserves the architectural roles needed for PyHazards integration: + +- optical / infrared dual encoder +- class-aware memory enhancement +- instruction-conditioned feature fusion +- dense segmentation head + +It is suitable for smoke testing and benchmark integration, while remaining +transparent about not reproducing the original external MLLM runtime. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="firemm_ir", + task="segmentation", + in_channels=6, + out_dim=1, + ) + + x = torch.randn(2, 6, 32, 32) + logits = model(x) + print(logits.shape) diff --git a/docs/source/modules/models_firepred.rst b/docs/source/modules/models_firepred.rst new file mode 100644 index 00000000..919b78cc --- /dev/null +++ b/docs/source/modules/models_firepred.rst @@ -0,0 +1,46 @@ +FirePred +======== + +Description +----------- + +``firepred`` is a PyHazards port inspired by the FirePred wildfire spread model. + +This implementation keeps the benchmark-relevant structure of the published method: + +- multi-temporal wildfire raster input +- separate recent, aggregated, and snapshot branches +- fused CNN decoding for next-step wildfire spread prediction + +Paper / source +-------------- + +- `FirePred GitHub repository `_ +- Paper title used by the official repository: ``FirePred: A hybrid multi-temporal convolutional neural network model for wildfire spread prediction`` + +Paper parity note +----------------- + +This PyHazards implementation is intentionally a lightweight benchmark-facing port. +It preserves the multi-temporal hybrid-CNN pattern while avoiding notebook-only or +project-specific training code from the original release. + +Example of how to use it +------------------------ + +.. code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="firepred", + task="segmentation", + history=5, + in_channels=8, + out_channels=1, + ) + + x = torch.randn(2, 5, 8, 32, 32) + logits = model(x) + print(logits.shape) diff --git a/docs/source/modules/models_forefire.rst b/docs/source/modules/models_forefire.rst index cc0f3b6a..58d1fdec 100644 --- a/docs/source/modules/models_forefire.rst +++ b/docs/source/modules/models_forefire.rst @@ -1,107 +1,62 @@ -.. This file is generated by scripts/render_model_docs.py. 
Do not edit by hand. - ForeFire Adapter ================ -Overview --------- - -``forefire`` is a deterministic raster adapter that approximates simulator-style front propagation through fixed diffusion kernels. - -At a Glance ------------ - -.. grid:: 1 2 4 4 - :gutter: 2 - :class-container: catalog-grid - - .. grid-item-card:: Hazard Family - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - Wildfire - - .. container:: catalog-stat-note - - Public catalog grouping used for this model. - - .. grid-item-card:: Maturity - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - Implemented - - .. container:: catalog-stat-note - - Catalog maturity label used on the index page. - - .. grid-item-card:: Tasks - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - 1 - - .. container:: catalog-stat-note - - Spread - - .. grid-item-card:: Benchmark Family - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - :doc:`Wildfire Benchmark ` - - .. container:: catalog-stat-note - - Primary benchmark-family link used for compatible evaluation coverage. - - Description ----------- -``forefire`` is a deterministic raster adapter that approximates simulator-style front propagation through fixed diffusion kernels. - -PyHazards exposes it as a benchmarkable baseline through the standard model registry. +``forefire`` is a lightweight PyHazards raster adapter inspired by the +front-propagation behavior of the ForeFire wildfire spread simulator. 
-Benchmark Compatibility ------------------------ +This module is designed as a benchmark-facing canonical model that keeps the +main local spread mechanism simple and reproducible inside the PyHazards +library: -**Primary benchmark family:** :doc:`Wildfire Benchmark ` +- ``in_channels=12`` +- ``out_channels=1`` +- ``diffusion_steps=2`` by default +- repeated neighborhood spread updates +- explicit fuel and wind modulation -**Mapped benchmark ecosystems:** :doc:`WildfireSpreadTS ` +Paper / source +-------------- -External References -------------------- +- `ForeFire: open source code for wildland fire spread models `_ +- `ForeFire repository `_ -**Paper:** `ForeFire: A Modular, Scriptable C++ Simulation Engine and Library for Wildland-Fire Spread `_ | **Repo:** `Repository `__ +Paper parity note +----------------- -Registry Name -------------- +This PyHazards implementation is intentionally **not** the full ForeFire +simulation system. Instead, it provides a compact raster adapter that captures +the main deterministic spread intuition needed for registry integration and +smoke testing in the main library. -Primary entrypoint: ``forefire`` +The canonical PyHazards version keeps: -Supported Tasks ---------------- +- raster input/output contract +- repeated local front spread updates +- fuel-conditioned spread +- wind-conditioned spread -- Spread +It does not attempt to reproduce the full propagation solver, landscape +representation, or operational simulation stack of the original ForeFire +system. -Programmatic Use ----------------- +Example of how to use it +------------------------ .. 
code-block:: python import torch from pyhazards.models import build_model - model = build_model(name="forefire", task="segmentation", in_channels=12) - logits = model(torch.randn(2, 12, 16, 16)) - print(logits.shape) - -Notes ------ + model = build_model( + name="forefire", + task="segmentation", + in_channels=12, + diffusion_steps=2, + ) -- This adapter is deterministic and does not learn parameters during the smoke test. + x = torch.randn(2, 12, 32, 32) + spread = model(x) + print(spread.shape) diff --git a/docs/source/modules/models_gemini_25_pro_wildfire_prompted.rst b/docs/source/modules/models_gemini_25_pro_wildfire_prompted.rst new file mode 100644 index 00000000..2e08e013 --- /dev/null +++ b/docs/source/modules/models_gemini_25_pro_wildfire_prompted.rst @@ -0,0 +1,47 @@ +Gemini 2.5 Pro Wildfire Prompted +================================ + +Description +----------- + +``gemini_25_pro_wildfire_prompted`` is a benchmark-facing prompt-conditioned VLM port +inspired by Gemini 2.5 Pro. + +This implementation keeps the integration-relevant structure for a generic wildfire +vision-language baseline: + +- raster wildfire/environment input +- prompt-token conditioning +- visual-token and prompt-token fusion +- dense wildfire-risk decoding + +Paper / source +-------------- + +- `Gemini models documentation `_ + +Paper parity note +----------------- + +This PyHazards implementation is intentionally not a checkpoint-level port of +Gemini 2.5 Pro. Instead, it is a compact prompt-conditioned wildfire segmentation +baseline that preserves the benchmark-relevant VLM pattern. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="gemini_25_pro_wildfire_prompted", + task="segmentation", + in_channels=6, + out_dim=1, + ) + + x = torch.randn(2, 6, 32, 32) + logits = model(x) + print(logits.shape) diff --git a/docs/source/modules/models_internvl3_wildfire_prompted.rst b/docs/source/modules/models_internvl3_wildfire_prompted.rst new file mode 100644 index 00000000..22d9cd35 --- /dev/null +++ b/docs/source/modules/models_internvl3_wildfire_prompted.rst @@ -0,0 +1,47 @@ +InternVL3 Wildfire Prompted +=========================== + +Description +----------- + +``internvl3_wildfire_prompted`` is a benchmark-facing prompt-conditioned VLM port +inspired by InternVL3. + +This implementation keeps the integration-relevant structure for a generic wildfire +vision-language baseline: + +- raster wildfire/environment input +- prompt-token conditioning +- visual-token and prompt-token fusion +- dense wildfire-risk decoding + +Paper / source +-------------- + +- `InternVL repository `_ + +Paper parity note +----------------- + +This PyHazards implementation is intentionally not a checkpoint-level port of +InternVL3. Instead, it is a compact prompt-conditioned wildfire segmentation +baseline that preserves the benchmark-relevant VLM pattern. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="internvl3_wildfire_prompted", + task="segmentation", + in_channels=6, + out_dim=1, + ) + + x = torch.randn(2, 6, 32, 32) + logits = model(x) + print(logits.shape) diff --git a/docs/source/modules/models_lightgbm.rst b/docs/source/modules/models_lightgbm.rst new file mode 100644 index 00000000..e151d712 --- /dev/null +++ b/docs/source/modules/models_lightgbm.rst @@ -0,0 +1,42 @@ +LightGBM +======== + +Description +----------- + +``lightgbm`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.lightgbm_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="lightgbm", + task="classification", + ) + + if "classification" == "classification": + x = torch.randn(4, 16) + else: + x = torch.randn(2, 1, 32, 32) + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_llama4_wildfire_prompted.rst b/docs/source/modules/models_llama4_wildfire_prompted.rst new file mode 100644 index 00000000..dfabc6d1 --- /dev/null +++ b/docs/source/modules/models_llama4_wildfire_prompted.rst @@ -0,0 +1,48 @@ +Llama 4 Wildfire Prompted +========================= + +Description +----------- + +``llama4_wildfire_prompted`` is a benchmark-facing prompt-conditioned multimodal port +inspired by Meta Llama 4. + +This implementation keeps the integration-relevant structure for a generic wildfire +vision-language baseline: + +- raster wildfire/environment input +- prompt-token conditioning +- visual-token and prompt-token fusion +- dense wildfire-risk decoding + +Paper / source +-------------- + +- `Meta Llama organization `_ +- `Llama site `_ + +Paper parity note +----------------- + +This PyHazards implementation is intentionally not a checkpoint-level port of +Llama 4. Instead, it is a compact prompt-conditioned wildfire segmentation +baseline that preserves the benchmark-relevant multimodal reasoning pattern. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="llama4_wildfire_prompted", + task="segmentation", + in_channels=6, + out_dim=1, + ) + + x = torch.randn(2, 6, 32, 32) + logits = model(x) + print(logits.shape) diff --git a/docs/source/modules/models_logistic_regression.rst b/docs/source/modules/models_logistic_regression.rst new file mode 100644 index 00000000..8fde8cac --- /dev/null +++ b/docs/source/modules/models_logistic_regression.rst @@ -0,0 +1,42 @@ +Logistic Regression +=================== + +Description +----------- + +``logistic_regression`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.logistic_regression_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="logistic_regression", + task="classification", + ) + + if "classification" == "classification": + x = torch.randn(4, 16) + else: + x = torch.randn(2, 1, 32, 32) + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_mau.rst b/docs/source/modules/models_mau.rst new file mode 100644 index 00000000..bb4c78a0 --- /dev/null +++ b/docs/source/modules/models_mau.rst @@ -0,0 +1,42 @@ +MAU +=== + +Description +----------- + +``mau`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.mau_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="mau", + task="segmentation", + ) + + if "segmentation" == "classification": + x = torch.randn(4, 16) + else: + x = torch.randn(2, 1, 32, 32) + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_modis_active_fire_c61.rst b/docs/source/modules/models_modis_active_fire_c61.rst new file mode 100644 index 00000000..ad65c1fa --- /dev/null +++ b/docs/source/modules/models_modis_active_fire_c61.rst @@ -0,0 +1,48 @@ +MODIS Active Fire C6.1 +====================== + +Description +----------- + +``modis_active_fire_c61`` is a PyHazards operational-detection baseline inspired by +NASA's MODIS Collection 6.1 active-fire algorithm and its FIRMS-facing use in practice. + +This implementation keeps the benchmark-relevant structure of the published method: + +- satellite active-fire detection framing rather than generic segmentation +- contextual thermal anomaly estimation at coarser MODIS-like support +- split-window style evidence between mid-IR and longwave channels +- lightweight learnable calibration head so the method can run under the PyHazards benchmark contract + +Paper / source +-------------- + +- `MODIS Land Team fire page `_ +- `Giglio et al. (2016) `_ + +Paper parity note +----------------- + +This PyHazards implementation is intentionally a benchmark-facing surrogate rather than a byte-for-byte +reproduction of the NASA operational code path. It preserves the operational-detection intuition of +contextual thermal anomaly plus spectral evidence, while adding a compact learnable calibration head so +that smoke runs can generate standard training artifacts. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="modis_active_fire_c61", + task="segmentation", + in_channels=5, + out_dim=1, + ) + + x = torch.randn(2, 5, 32, 32) + logits = model(x) + print(logits.shape) diff --git a/docs/source/modules/models_predrnn_v2.rst b/docs/source/modules/models_predrnn_v2.rst new file mode 100644 index 00000000..448d31da --- /dev/null +++ b/docs/source/modules/models_predrnn_v2.rst @@ -0,0 +1,42 @@ +PredRNN-v2 +========== + +Description +----------- + +``predrnn_v2`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.predrnn_v2_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="predrnn_v2", + task="segmentation", + ) + + if "segmentation" == "classification": + x = torch.randn(4, 16) + else: + x = torch.randn(2, 1, 32, 32) + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_prithvi_burnscars.rst b/docs/source/modules/models_prithvi_burnscars.rst new file mode 100644 index 00000000..bbe1b5fe --- /dev/null +++ b/docs/source/modules/models_prithvi_burnscars.rst @@ -0,0 +1,55 @@ +Prithvi BurnScars +================= + +Description +----------- + +``prithvi_burnscars`` is a lightweight PyHazards downstream segmentation model +inspired by the official Prithvi BurnScars release. + +This module keeps the benchmark-relevant ideas from the model card: + +- Prithvi-style EO temporal backbone +- single-timestamp or arbitrary-timestamp fine-tuning support +- burn-scar-style segmentation head +- U-Net-like skip fusion for dense output + +Paper / source +-------------- + +- `Prithvi-EO-2.0: A Versatile Multi-Temporal Foundation Model for Earth Observation Applications `_ +- `Prithvi-EO-2.0-300M-BurnScars model card `_ + +Paper parity note +----------------- + +This PyHazards implementation is intentionally **not** the official released +checkpoint. Instead, it is a benchmark-facing downstream port that preserves the +main architectural story of the official BurnScars release: + +- EO foundation-style encoder +- downstream burn-scar segmentation objective +- dense decoder with skip fusion + +It is suitable for PyHazards integration and smoke testing, while remaining +transparent about not being a weight-identical reproduction. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="prithvi_burnscars", + task="segmentation", + image_size=32, + in_channels=6, + out_dim=1, + ) + + x = torch.randn(2, 1, 6, 32, 32) + logits = model(x) + print(logits.shape) diff --git a/docs/source/modules/models_prithvi_eo_2_tl.rst b/docs/source/modules/models_prithvi_eo_2_tl.rst new file mode 100644 index 00000000..9113b19d --- /dev/null +++ b/docs/source/modules/models_prithvi_eo_2_tl.rst @@ -0,0 +1,57 @@ +Prithvi-EO-2.0-TL +================= + +Description +----------- + +``prithvi_eo_2_tl`` is a lightweight PyHazards port inspired by the +Prithvi-EO-2.0 transfer-learning model family. + +This module keeps the main ideas highlighted in the official paper/model card: + +- multi-temporal EO input sequences +- temporal embeddings +- location embeddings +- transformer-style EO backbone +- segmentation-ready downstream head + +Paper / source +-------------- + +- `Prithvi-EO-2.0: A Versatile Multi-Temporal Foundation Model for Earth Observation Applications `_ +- `Prithvi-EO-2.0-300M-TL model card `_ + +Paper parity note +----------------- + +This PyHazards implementation is intentionally **not** the full official +pretrained foundation model with released checkpoints. Instead, it is a clean +PyTorch port that preserves the benchmark-relevant architectural ideas needed +for PyHazards integration: + +- sequence-based EO input handling +- temporal and location conditioning +- transformer encoder over patch tokens +- downstream segmentation decoding + +It does not claim checkpoint parity with the official IBM-NASA release. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="prithvi_eo_2_tl", + task="segmentation", + image_size=32, + in_channels=6, + out_dim=1, + ) + + x = torch.randn(2, 4, 6, 32, 32) + logits = model(x) + print(logits.shape) diff --git a/docs/source/modules/models_prithvi_wxc.rst b/docs/source/modules/models_prithvi_wxc.rst new file mode 100644 index 00000000..7a470547 --- /dev/null +++ b/docs/source/modules/models_prithvi_wxc.rst @@ -0,0 +1,57 @@ +Prithvi-WxC +=========== + +Description +----------- + +``prithvi_wxc`` is a lightweight PyHazards port inspired by the +Prithvi-WxC weather-climate foundation-model family. + +This module keeps the main ideas highlighted in the official paper/model card: + +- multi-step weather input sequences +- lead-time conditioning +- variable-summary conditioning +- transformer-style weather backbone +- dense downstream head for wildfire-style grid prediction + +Paper / source +-------------- + +- `Prithvi WxC: Foundation Model for Weather and Climate `_ +- `Prithvi-WxC model card `_ + +Paper parity note +----------------- + +This PyHazards implementation is intentionally **not** the full official +pretrained Prithvi-WxC checkpoint stack. Instead, it is a clean PyTorch port +that preserves the benchmark-relevant ideas we need for integration: + +- multi-variable weather sequence handling +- lead-time-aware conditioning +- transformer encoder over weather patch tokens +- dense wildfire-risk decoding + +It does not claim checkpoint parity with the official NASA/IBM release. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="prithvi_wxc", + task="segmentation", + image_size=32, + in_channels=8, + out_dim=1, + ) + + x = torch.randn(2, 5, 8, 32, 32) + lead_time = torch.linspace(6.0, 30.0, 5).repeat(2, 1) + logits = model({"x": x, "lead_time_hours": lead_time}) + print(logits.shape) diff --git a/docs/source/modules/models_qwen25_vl_wildfire_prompted.rst b/docs/source/modules/models_qwen25_vl_wildfire_prompted.rst new file mode 100644 index 00000000..d0a2c7d5 --- /dev/null +++ b/docs/source/modules/models_qwen25_vl_wildfire_prompted.rst @@ -0,0 +1,52 @@ +Qwen2.5-VL Wildfire Prompted +============================ + +Description +----------- + +``qwen25_vl_wildfire_prompted`` is a benchmark-facing prompt-conditioned VLM port +inspired by Qwen2.5-VL. + +This implementation keeps the integration-relevant structure for a generic wildfire +vision-language baseline: + +- raster wildfire/environment input +- prompt-token conditioning +- visual-token and prompt-token fusion +- dense wildfire-risk decoding + +Paper / source +-------------- + +- `QwenLM/Qwen2.5-VL GitHub repository `_ +- `Qwen2.5-VL Technical Report `_ + +Paper parity note +----------------- + +This PyHazards implementation is intentionally not a full parameter-port of the +released Qwen2.5-VL checkpoints. Instead, it is a compact prompt-conditioned +wildfire segmentation baseline that preserves the benchmark-relevant VLM pattern: + +- prompt-conditioned visual reasoning +- image-token and prompt-token fusion +- dense downstream wildfire prediction head + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="qwen25_vl_wildfire_prompted", + task="segmentation", + in_channels=6, + out_dim=1, + ) + + x = torch.randn(2, 6, 32, 32) + logits = model(x) + print(logits.shape) diff --git a/docs/source/modules/models_rainformer.rst b/docs/source/modules/models_rainformer.rst new file mode 100644 index 00000000..b5cf7329 --- /dev/null +++ b/docs/source/modules/models_rainformer.rst @@ -0,0 +1,42 @@ +Rainformer +========== + +Description +----------- + +``rainformer`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.rainformer_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="rainformer", + task="segmentation", + ) + + if "segmentation" == "classification": + x = torch.randn(4, 16) + else: + x = torch.randn(2, 1, 32, 32) + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_random_forest.rst b/docs/source/modules/models_random_forest.rst new file mode 100644 index 00000000..3cb20ff6 --- /dev/null +++ b/docs/source/modules/models_random_forest.rst @@ -0,0 +1,42 @@ +Random Forest +============= + +Description +----------- + +``random_forest`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.random_forest_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="random_forest", + task="classification", + ) + + x = torch.randn(4, 16) + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_resnet18_unet.rst b/docs/source/modules/models_resnet18_unet.rst new file mode 100644 index 00000000..d55b5403 --- /dev/null +++ b/docs/source/modules/models_resnet18_unet.rst @@ -0,0 +1,42 @@ +ResNet-18 U-Net +=============== + +Description +----------- + +``resnet18_unet`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.resnet18_unet_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="resnet18_unet", + task="segmentation", + ) + + x = torch.randn(2, 1, 32, 32) + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_segformer.rst b/docs/source/modules/models_segformer.rst new file mode 100644 index 00000000..429688b6 --- /dev/null +++ b/docs/source/modules/models_segformer.rst @@ -0,0 +1,42 @@ +SegFormer +========= + +Description +----------- + +``segformer`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.segformer_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="segformer", + task="segmentation", + ) + + # segformer is a segmentation-task baseline, so it takes image-shaped tensors. + x = torch.randn(2, 1, 32, 32) + # A classification-task model would instead take flat feature vectors, + # e.g. torch.randn(4, 16). + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_swin_unet.rst b/docs/source/modules/models_swin_unet.rst new file mode 100644 index 00000000..32f4cfd5 --- /dev/null +++ b/docs/source/modules/models_swin_unet.rst @@ -0,0 +1,42 @@ +Swin-UNet +========= + +Description +----------- + +``swin_unet`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.swin_unet_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +..
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="swin_unet", + task="segmentation", + ) + + # swin_unet is a segmentation-task baseline, so it takes image-shaped tensors. + x = torch.randn(2, 1, 32, 32) + # A classification-task model would instead take flat feature vectors, + # e.g. torch.randn(4, 16). + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_swinlstm.rst b/docs/source/modules/models_swinlstm.rst new file mode 100644 index 00000000..1a576ebd --- /dev/null +++ b/docs/source/modules/models_swinlstm.rst @@ -0,0 +1,42 @@ +SwinLSTM +======== + +Description +----------- + +``swinlstm`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.swinlstm_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +..
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="swinlstm", + task="segmentation", + ) + + # swinlstm is a segmentation-task baseline, so it takes image-shaped tensors. + x = torch.randn(2, 1, 32, 32) + # A classification-task model would instead take flat feature vectors, + # e.g. torch.randn(4, 16). + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_tcn.rst b/docs/source/modules/models_tcn.rst new file mode 100644 index 00000000..46f7055a --- /dev/null +++ b/docs/source/modules/models_tcn.rst @@ -0,0 +1,42 @@ +TCN +=== + +Description +----------- + +``tcn`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.tcn_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +..
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="tcn", + task="segmentation", + ) + + # tcn is a segmentation-task baseline, so it takes image-shaped tensors. + x = torch.randn(2, 1, 32, 32) + # A classification-task model would instead take flat feature vectors, + # e.g. torch.randn(4, 16). + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_ts_satfire.rst b/docs/source/modules/models_ts_satfire.rst new file mode 100644 index 00000000..01c215da --- /dev/null +++ b/docs/source/modules/models_ts_satfire.rst @@ -0,0 +1,50 @@ +TS-SatFire +========== + +Description +----------- + +``ts_satfire`` is a lightweight PyHazards port inspired by the TS-SatFire +multi-temporal wildfire prediction benchmark family. + +This module keeps the benchmark-relevant ideas we need for integration: + +- multi-temporal satellite image sequences +- auxiliary environmental channels +- spatio-temporal raster encoding +- dense wildfire progression prediction + +Paper / source +-------------- + +- `TS-SatFire paper `_ +- `TS-SatFire official repository `_ + +Paper parity note +----------------- + +This PyHazards implementation is intentionally **not** the entire official +TS-SatFire processing and benchmark stack. Instead, it is a clean spatio-temporal +port that preserves the prediction-task modeling role needed for benchmark integration. + +It does not claim exact architecture, dataset, or training parity with the original release. + +Example of how to use it +------------------------ + +..
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="ts_satfire", + task="segmentation", + history=5, + in_channels=8, + out_channels=1, + ) + + x = torch.randn(2, 5, 8, 32, 32) + logits = model(x) + print(logits.shape) diff --git a/docs/source/modules/models_unet.rst b/docs/source/modules/models_unet.rst new file mode 100644 index 00000000..375a0a2a --- /dev/null +++ b/docs/source/modules/models_unet.rst @@ -0,0 +1,42 @@ +U-Net +===== + +Description +----------- + +``unet`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.unet_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="unet", + task="segmentation", + ) + + # unet is a segmentation-task baseline, so it takes image-shaped tensors. + x = torch.randn(2, 1, 32, 32) + # A classification-task model would instead take flat feature vectors, + # e.g. torch.randn(4, 16). + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_utae.rst b/docs/source/modules/models_utae.rst new file mode 100644 index 00000000..1dcac233 --- /dev/null +++ b/docs/source/modules/models_utae.rst @@ -0,0 +1,42 @@ +UTAE +==== + +Description +----------- + +``utae`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.utae_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +..
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="utae", + task="segmentation", + ) + + # utae is a segmentation-task baseline, so it takes image-shaped tensors. + x = torch.randn(2, 1, 32, 32) + # A classification-task model would instead take flat feature vectors, + # e.g. torch.randn(4, 16). + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_viirs_375m_active_fire.rst b/docs/source/modules/models_viirs_375m_active_fire.rst new file mode 100644 index 00000000..9cbfd9f2 --- /dev/null +++ b/docs/source/modules/models_viirs_375m_active_fire.rst @@ -0,0 +1,48 @@ +VIIRS 375 m Active Fire +======================= + +Description +----------- + +``viirs_375m_active_fire`` is a PyHazards operational-detection baseline inspired by +NASA's VIIRS 375 m active-fire algorithm and its FIRMS-facing use in practice. + +This implementation keeps the benchmark-relevant structure of the published method: + +- satellite active-fire detection framing rather than generic segmentation +- contextual thermal anomaly estimation +- split-window style evidence between mid-IR and longwave channels +- lightweight learnable calibration head so the method can run under the PyHazards benchmark contract + +Paper / source +-------------- + +- `NASA Earthdata VIIRS I-Band 375 m Active Fire page `_ +- `Schroeder et al. (2014) `_ + +Paper parity note +----------------- + +This PyHazards implementation is intentionally a benchmark-facing surrogate rather than a byte-for-byte +reproduction of the NASA operational code path. It preserves the operational-detection intuition of +contextual thermal anomaly plus spectral evidence, while adding a compact learnable calibration head so +that smoke runs can generate standard training artifacts. + +Example of how to use it +------------------------ + +..
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="viirs_375m_active_fire", + task="segmentation", + in_channels=5, + out_dim=1, + ) + + x = torch.randn(2, 5, 32, 32) + logits = model(x) + print(logits.shape) diff --git a/docs/source/modules/models_vit_segmenter.rst b/docs/source/modules/models_vit_segmenter.rst new file mode 100644 index 00000000..a4c77307 --- /dev/null +++ b/docs/source/modules/models_vit_segmenter.rst @@ -0,0 +1,42 @@ +ViT Segmenter +============= + +Description +----------- + +``vit_segmenter`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.vit_segmenter_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +.. 
code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="vit_segmenter", + task="segmentation", + ) + + if "segmentation" == "classification": + x = torch.randn(4, 16) + else: + x = torch.randn(2, 1, 32, 32) + out = model(x) + print(type(out)) diff --git a/docs/source/modules/models_wildfire_forecasting.rst b/docs/source/modules/models_wildfire_forecasting.rst deleted file mode 100644 index 82a6b73a..00000000 --- a/docs/source/modules/models_wildfire_forecasting.rst +++ /dev/null @@ -1,111 +0,0 @@ -.. This file is generated by scripts/render_model_docs.py. Do not edit by hand. - -Wildfire Forecasting -==================== - -Overview --------- - -``wildfire_forecasting`` is a compact GRU-attention forecaster for weekly wildfire activity windows. - -At a Glance ------------ - -.. grid:: 1 2 4 4 - :gutter: 2 - :class-container: catalog-grid - - .. grid-item-card:: Hazard Family - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - Wildfire - - .. container:: catalog-stat-note - - Public catalog grouping used for this model. - - .. grid-item-card:: Maturity - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - Implemented - - .. container:: catalog-stat-note - - Catalog maturity label used on the index page. - - .. grid-item-card:: Tasks - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - 1 - - .. container:: catalog-stat-note - - Forecasting - - .. grid-item-card:: Benchmark Family - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - :doc:`Wildfire Benchmark ` - - .. container:: catalog-stat-note - - Primary benchmark-family link used for compatible evaluation coverage. - - -Description ------------ - -``wildfire_forecasting`` is a compact GRU-attention forecaster for weekly wildfire activity windows. 
- -The PyHazards implementation targets smoke-testable next-window size-group prediction through the shared wildfire benchmark flow. - -Benchmark Compatibility ------------------------ - -**Primary benchmark family:** :doc:`Wildfire Benchmark ` - -External References -------------------- - -**Paper:** `Wildfire Danger Prediction and Understanding with Deep Learning `_ | **Repo:** `Repository `__ - -Registry Name -------------- - -Primary entrypoint: ``wildfire_forecasting`` - -Supported Tasks ---------------- - -- Forecasting - -Programmatic Use ----------------- - -.. code-block:: python - - import torch - from pyhazards.models import build_model - - model = build_model( - name="wildfire_forecasting", - task="forecasting", - input_dim=7, - output_dim=5, - lookback=12, - ) - preds = model(torch.randn(2, 12, 7)) - print(preds.shape) - -Notes ------ - -- This public adapter is exercised on the weekly wildfire smoke benchmark. diff --git a/docs/source/modules/models_wildfiregpt.rst b/docs/source/modules/models_wildfiregpt.rst new file mode 100644 index 00000000..67aabd7e --- /dev/null +++ b/docs/source/modules/models_wildfiregpt.rst @@ -0,0 +1,55 @@ +WildfireGPT +=========== + +Description +----------- + +``wildfiregpt`` is a benchmark-facing PyHazards port inspired by the +WildfireGPT multi-agent retrieval-augmented generation system. + +This module preserves the main ideas emphasized by the official paper/repository: + +- user profile conditioning +- planning / analyst style system-role tokens +- retrieved knowledge conditioning +- decision-support style fusion before producing a wildfire risk map + +Paper / source +-------------- + +- `MARSHA: multi-agent RAG system for hazard adaptation `_ +- `WildfireGPT repository `_ + +Paper parity note +----------------- + +This PyHazards implementation is intentionally **not** the original Streamlit + +OpenAI Assistant API system. 
Instead, it is a benchmark-friendly neural port +that preserves the architectural roles needed for PyHazards integration: + +- user-profile representation +- retrieved-context representation +- multi-agent style orchestration tokens +- downstream wildfire risk decoding + +It is suitable for smoke testing and benchmark integration, while remaining +transparent about not reproducing the external hosted LLM stack. + +Example of how to use it +------------------------ + +.. code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="wildfiregpt", + task="segmentation", + in_channels=12, + out_dim=1, + ) + + x = torch.randn(2, 12, 32, 32) + logits = model(x) + print(logits.shape) diff --git a/docs/source/modules/models_wrf_sfire.rst b/docs/source/modules/models_wrf_sfire.rst index 85546f73..84caef96 100644 --- a/docs/source/modules/models_wrf_sfire.rst +++ b/docs/source/modules/models_wrf_sfire.rst @@ -1,107 +1,60 @@ -.. This file is generated by scripts/render_model_docs.py. Do not edit by hand. - WRF-SFIRE Adapter ================= -Overview --------- - -``wrf_sfire`` approximates simulator-style spread transport with a fixed diffusion and terrain-moisture modulation layer. - -At a Glance ------------ - -.. grid:: 1 2 4 4 - :gutter: 2 - :class-container: catalog-grid - - .. grid-item-card:: Hazard Family - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - Wildfire - - .. container:: catalog-stat-note - - Public catalog grouping used for this model. - - .. grid-item-card:: Maturity - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - Implemented - - .. container:: catalog-stat-note - - Catalog maturity label used on the index page. - - .. grid-item-card:: Tasks - :class-card: catalog-stat-card - - .. container:: catalog-stat-value - - 1 - - .. container:: catalog-stat-note - - Spread - - .. grid-item-card:: Benchmark Family - :class-card: catalog-stat-card - - .. 
container:: catalog-stat-value - - :doc:`Wildfire Benchmark ` - - .. container:: catalog-stat-note - - Primary benchmark-family link used for compatible evaluation coverage. - - Description ----------- -``wrf_sfire`` approximates simulator-style spread transport with a fixed diffusion and terrain-moisture modulation layer. - -The PyHazards adapter is designed for consistent smoke benchmarking rather than full physical simulation. +``wrf_sfire`` is a lightweight PyHazards raster adapter inspired by the +transport-and-diffusion behavior of the WRF-SFIRE wildfire spread system. -Benchmark Compatibility ------------------------ +This module is designed as a benchmark-facing canonical model that keeps the +main spread intuition simple inside the PyHazards library: -**Primary benchmark family:** :doc:`Wildfire Benchmark ` +- ``in_channels=12`` +- ``out_channels=1`` +- ``diffusion_steps=3`` by default +- local transport via a fixed spread kernel +- terrain and moisture modulation during repeated spread steps -**Mapped benchmark ecosystems:** :doc:`WildfireSpreadTS ` +Paper / source +-------------- -External References -------------------- +- `Coupled atmosphere-wildland fire modeling with WRF 3.3 and SFIRE 2011 `_ +- `WRF-SFIRE repository `_ -**Paper:** `Coupled atmosphere-wildland fire modeling with WRF 3.3 and SFIRE 2011 `_ | **Repo:** `Repository `__ +Paper parity note +----------------- -Registry Name -------------- +This PyHazards implementation is intentionally **not** the full WRF-SFIRE +coupled simulator. Instead, it provides a compact raster adapter that preserves +the main local-spread intuition needed for benchmark integration and smoke +testing inside the main library. 
-Primary entrypoint: ``wrf_sfire`` +The canonical PyHazards version keeps: -Supported Tasks ---------------- +- raster input/output contract +- repeated local diffusion +- terrain-aware spread scaling +- moisture damping -- Spread +It does not attempt to reproduce the full atmospheric coupling, mesh handling, +or solver stack of the original WRF-SFIRE system. -Programmatic Use ----------------- +Example of how to use it +------------------------ .. code-block:: python import torch from pyhazards.models import build_model - model = build_model(name="wrf_sfire", task="segmentation", in_channels=12) - logits = model(torch.randn(2, 12, 16, 16)) - print(logits.shape) - -Notes ------ + model = build_model( + name="wrf_sfire", + task="segmentation", + in_channels=12, + diffusion_steps=3, + ) -- This smoke-path adapter keeps the simulator slot benchmarkable without external binaries. + x = torch.randn(2, 12, 32, 32) + spread = model(x) + print(spread.shape) diff --git a/docs/source/modules/models_xgboost.rst b/docs/source/modules/models_xgboost.rst new file mode 100644 index 00000000..b47c7bda --- /dev/null +++ b/docs/source/modules/models_xgboost.rst @@ -0,0 +1,42 @@ +XGBoost +======= + +Description +----------- + +``xgboost`` is the canonical PyHazards promotion of the wildfire benchmark Track-O baseline. + +This module is kept in ``pyhazards.models`` so the main branch can treat the benchmark baseline as a first-class model implementation. + +It is primarily intended for benchmark integration, smoke testing, and registry-based construction through ``build_model(...)``. + +Paper / source +-------------- + +- Promoted from the wildfire benchmark Track-O model family in PyHazards. +- Source implementation lineage: ``pyhazards.pipelines.wildfire_benchmark.models.xgboost_track_o``. + +Paper parity note +----------------- + +This PyHazards implementation is intentionally benchmark-facing. 
It preserves the modeling role of the Track-O baseline while making the model available from the main ``pyhazards.models`` layer. + +Example of how to use it +------------------------ + +.. code-block:: python + + import torch + from pyhazards.models import build_model + + model = build_model( + name="xgboost", + task="classification", + ) + + if "classification" == "classification": + x = torch.randn(4, 16) + else: + x = torch.randn(2, 1, 32, 32) + out = model(x) + print(type(out)) diff --git a/docs/source/pyhazards_datasets.rst b/docs/source/pyhazards_datasets.rst index 31b4f17a..2064dd14 100644 --- a/docs/source/pyhazards_datasets.rst +++ b/docs/source/pyhazards_datasets.rst @@ -30,7 +30,7 @@ At a Glance .. container:: catalog-stat-value - 20 + 34 .. container:: catalog-stat-note @@ -41,7 +41,7 @@ At a Glance .. container:: catalog-stat-value - 10 + 24 .. container:: catalog-stat-note @@ -174,6 +174,254 @@ primary source, and the most relevant inspection or registry surface. **Primary Source:** `Gelaro et al. (2017). The Modern-Era Retrospective Analysis for Research and Applications, Version 2 (MERRA-2). `_ + .. grid-item-card:: HPWREN Weather + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + Public HPWREN station feeds used for local weather-station context in wildfire operations and validation workflows. + + .. container:: catalog-chip-row + + :bdg-secondary:`Weather Stations` :bdg-info:`Station points with tabular observations` + + .. container:: catalog-meta-row + + **Coverage:** HPWREN station network footprint + + .. container:: catalog-meta-row + + **Update Cadence:** Real-time operational updates plus archived monthly summaries + + .. container:: catalog-meta-row + + **Inspection:** ``find /home/runyang/ryang/HPWREN_Weather -maxdepth 2 -type f | head`` + + .. container:: catalog-link-row + + **Details:** :doc:`HPWREN Weather ` + + .. container:: catalog-link-row + + **Primary Source:** `HPWREN `_ + + .. 
grid-item-card:: Spot Forecast + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + NOAA NWS spot forecast products used for incident-specific forecast guidance and fire-weather context. + + .. container:: catalog-chip-row + + :bdg-secondary:`Incident Forecast Guidance` :bdg-info:`Text and bulletin-style products` + + .. container:: catalog-meta-row + + **Coverage:** Incident-specific forecast products + + .. container:: catalog-meta-row + + **Update Cadence:** Generated when requested for active incidents + + .. container:: catalog-meta-row + + **Inspection:** ``find /home/runyang/ryang/Spot_Forecast_Current -maxdepth 2 -type f | head`` + + .. container:: catalog-link-row + + **Details:** :doc:`Spot Forecast ` + + .. container:: catalog-link-row + + **Primary Source:** `NWS Spot Forecast page `_ + + .. grid-item-card:: NOHRSC SNODAS + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + Daily snow-analysis grids used as snow-state context for mountain wildfire and seasonal fuel workflows. + + .. container:: catalog-chip-row + + :bdg-secondary:`Snow Analysis` :bdg-info:`Gridded raster fields` + + .. container:: catalog-meta-row + + **Coverage:** Continental United States + + .. container:: catalog-meta-row + + **Update Cadence:** Daily + + .. container:: catalog-meta-row + + **Inspection:** ``find /home/runyang/ryang/NOHRSC_SNODAS_masked_2024 -maxdepth 2 -type d | head`` + + .. container:: catalog-link-row + + **Details:** :doc:`NOHRSC SNODAS ` + + .. container:: catalog-link-row + + **Primary Source:** `NOHRSC archived data and SNODAS description `_ + + .. grid-item-card:: HRRR + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + NOAA's rapid-refresh forecast system used for short-range wildfire weather features and forecast forcing. + + .. container:: catalog-chip-row + + :bdg-secondary:`Weather Forecast` :bdg-info:`Gridded forecast fields` + + .. 
container:: catalog-meta-row + + **Coverage:** CONUS-focused forecast domain + + .. container:: catalog-meta-row + + **Update Cadence:** Hourly + + .. container:: catalog-meta-row + + **Inspection:** ``find /home/runyang/ryang/HRRR/2024 -maxdepth 3 -type f | head`` + + .. container:: catalog-link-row + + **Details:** :doc:`HRRR ` + + .. container:: catalog-link-row + + **Primary Source:** `HRRR official page `_ + + .. grid-item-card:: NDFD + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + Official NWS forecast grids and warning products used for wildfire-weather context and public hazard overlays. + + .. container:: catalog-chip-row + + :bdg-secondary:`Forecast and Warnings` :bdg-info:`Gridded forecast layers and bulletins` + + .. container:: catalog-meta-row + + **Coverage:** United States public forecast grids + + .. container:: catalog-meta-row + + **Update Cadence:** Issue-based for hazards and routine forecast refresh for grids + + .. container:: catalog-meta-row + + **Inspection:** ``find /home/runyang/ryang/NDFD -maxdepth 2 -type d | head`` + + .. container:: catalog-link-row + + **Details:** :doc:`NDFD ` + + .. container:: catalog-link-row + + **Primary Source:** `NDFD / digital.weather.gov `_ + + .. grid-item-card:: GOES GeoColor + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + GOES GeoColor imagery used for rapid visual fire-scene context and plume inspection. + + .. container:: catalog-chip-row + + :bdg-secondary:`Satellite Imagery Context` :bdg-info:`Geostationary imagery time series` + + .. container:: catalog-meta-row + + **Coverage:** GOES-East and GOES-West views over the Americas + + .. container:: catalog-meta-row + + **Update Cadence:** Continuous ingest as new imagery becomes available + + .. container:: catalog-meta-row + + **Inspection:** ``find /home/runyang/ryang/GOES_GeoColor_CIRA -maxdepth 3 -type f | head`` + + .. 
container:: catalog-link-row + + **Details:** :doc:`GOES GeoColor ` + + .. container:: catalog-link-row + + **Primary Source:** `CIRA Slider `_ + + .. grid-item-card:: NASA GIBS + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + NASA EOSDIS browse imagery used for daily wildfire scene context and qualitative event inspection. + + .. container:: catalog-chip-row + + :bdg-secondary:`Satellite Imagery Context` :bdg-info:`Tiled imagery and browse layers` + + .. container:: catalog-meta-row + + **Coverage:** Global + + .. container:: catalog-meta-row + + **Update Cadence:** Daily + + .. container:: catalog-meta-row + + **Inspection:** ``find /home/runyang/ryang/NASA_GIBS_2024 -maxdepth 3 -type f | head`` + + .. container:: catalog-link-row + + **Details:** :doc:`NASA GIBS ` + + .. container:: catalog-link-row + + **Primary Source:** `NASA GIBS overview `_ + + .. grid-item-card:: Synoptic Weather + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + Synoptic weather-station feeds used for local observation context and wildfire weather cross-checks. + + .. container:: catalog-chip-row + + :bdg-secondary:`Weather Stations` :bdg-info:`Station points with tabular observations and metadata` + + .. container:: catalog-meta-row + + **Coverage:** Multi-network station coverage where access is available + + .. container:: catalog-meta-row + + **Update Cadence:** Near-real-time for current feeds; historical access depends on plan tier + + .. container:: catalog-meta-row + + **Inspection:** ``find /home/runyang/ryang/Synoptic_Weather_Current -maxdepth 2 -type f | head`` + + .. container:: catalog-link-row + + **Details:** :doc:`Synoptic Weather ` + + .. container:: catalog-link-row + + **Primary Source:** `Synoptic Weather API `_ + .. tab-item:: Wildfire @@ -397,6 +645,216 @@ primary source, and the most relevant inspection or registry surface. **Primary Source:** `National Interagency Fire Center. 
Wildland Fire Incident Geospatial Services (WFIGS). `_ + .. grid-item-card:: FRAP Fire Perimeters + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + California's authoritative historical fire perimeter archive maintained by CAL FIRE FRAP. + + .. container:: catalog-chip-row + + :bdg-secondary:`Historical Perimeters` :bdg-info:`Vector fire perimeter polygons` + + .. container:: catalog-meta-row + + **Coverage:** California + + .. container:: catalog-meta-row + + **Update Cadence:** Annual spring releases with new fire-season perimeters + + .. container:: catalog-meta-row + + **Inspection:** ``ogrinfo -so "/home/runyang/ryang/FRAP_Fire_Perimeters/shapefile/California_Fire_Perimeters_(all).shp" "California_Fire_Perimeters_(all)"`` + + .. container:: catalog-meta-row + + **Related Benchmarks:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Details:** :doc:`FRAP Fire Perimeters ` + + .. container:: catalog-link-row + + **Primary Source:** `CAL FIRE FRAP Fire Perimeters `_ + + .. grid-item-card:: GeoMAC Historical + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + Historical GeoMAC wildfire perimeters preserved as a legacy U.S. perimeter archive for long-horizon evaluation. + + .. container:: catalog-chip-row + + :bdg-secondary:`Historical Perimeters` :bdg-info:`Archived wildfire perimeter polygons` + + .. container:: catalog-meta-row + + **Coverage:** United States + + .. container:: catalog-meta-row + + **Update Cadence:** Legacy archive; local copy is static + + .. container:: catalog-meta-row + + **Inspection:** ``unzip -l "/home/runyang/ryang/GeoMAC_Historical/Historic_Geomac_Perimeters_All_Years_2000_2018/Historic_Geomac_Perimeters_All_Years_2000_2018.zip" | head`` + + .. container:: catalog-meta-row + + **Related Benchmarks:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Details:** :doc:`GeoMAC Historical ` + + .. 
container:: catalog-link-row + + **Primary Source:** `USGS GeoMAC historical archive description `_ + + .. grid-item-card:: HMS Smoke + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + NOAA analyst-drawn smoke plume polygons used for smoke tracking, verification, and wildfire smoke exposure analysis. + + .. container:: catalog-chip-row + + :bdg-secondary:`Smoke Plumes` :bdg-info:`Vector smoke polygons` + + .. container:: catalog-meta-row + + **Coverage:** North America, Hawaii, and the Caribbean + + .. container:: catalog-meta-row + + **Update Cadence:** Sub-daily near-real-time analyst updates + + .. container:: catalog-meta-row + + **Inspection:** ``ogrinfo -so "/home/runyang/ryang/HMS_Smoke/2024/shapefile/hms_smoke2024.shp" hms_smoke2024`` + + .. container:: catalog-meta-row + + **Related Benchmarks:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Details:** :doc:`HMS Smoke ` + + .. container:: catalog-link-row + + **Primary Source:** `NOAA HMS Fire and Smoke Analysis `_ + + .. grid-item-card:: GOES-R FDCF + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + GOES-R ABI Fire/Hot Spot Characterization files used for high-frequency active-fire monitoring across the Americas. + + .. container:: catalog-chip-row + + :bdg-secondary:`Geostationary Active Fire` :bdg-info:`Raster NetCDF time series` + + .. container:: catalog-meta-row + + **Coverage:** GOES-East and GOES-West full-disk views + + .. container:: catalog-meta-row + + **Update Cadence:** About every 10 minutes + + .. container:: catalog-meta-row + + **Inspection:** ``python -m pyhazards.datasets.goesr.inspection --path /home/runyang/ryang/GOES_FDCF_G16/2024 --max-items 10`` + + .. container:: catalog-meta-row + + **Related Benchmarks:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Details:** :doc:`GOES-R FDCF ` + + .. 
container:: catalog-link-row + + **Primary Source:** `GOES-R Fire/Hot Spot Characterization `_ + + .. grid-item-card:: WRC Housing Density + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + Housing-density raster from Wildfire Risk to Communities used for WUI and exposure-aware wildfire analysis. + + .. container:: catalog-chip-row + + :bdg-secondary:`Exposure Context` :bdg-info:`Raster exposure layers` + + .. container:: catalog-meta-row + + **Coverage:** United States + + .. container:: catalog-meta-row + + **Update Cadence:** Release-based + + .. container:: catalog-meta-row + + **Inspection:** ``find /home/runyang/ryang/WRC_Housing_Density -maxdepth 3 -type f | head`` + + .. container:: catalog-meta-row + + **Related Benchmarks:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Details:** :doc:`WRC Housing Density ` + + .. container:: catalog-link-row + + **Primary Source:** `Wildfire Risk to Communities datasets `_ + + .. grid-item-card:: LandScan Population + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + Population raster used for population-at-risk and human exposure context in wildfire studies. + + .. container:: catalog-chip-row + + :bdg-secondary:`Population Exposure` :bdg-info:`Gridded population rasters` + + .. container:: catalog-meta-row + + **Coverage:** Global + + .. container:: catalog-meta-row + + **Update Cadence:** Release-based / annual + + .. container:: catalog-meta-row + + **Inspection:** ``find /home/runyang/ryang/LandScan_Global_2024 -maxdepth 3 -type f | head`` + + .. container:: catalog-meta-row + + **Related Benchmarks:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Details:** :doc:`LandScan Population ` + + .. container:: catalog-link-row + + **Primary Source:** `LandScan Global 2024 dataset entry `_ + .. tab-item:: Flood @@ -910,12 +1368,26 @@ model and evaluation coverage. 
datasets/era5 datasets/goesr datasets/merra2 + datasets/hpwren_weather + datasets/spot_forecast + datasets/nohrsc_snodas + datasets/hrrr + datasets/ndfd + datasets/goes_geocolor + datasets/nasa_gibs + datasets/synoptic_weather datasets/firms datasets/fpa_fod_tabular datasets/fpa_fod_weekly datasets/landfire datasets/mtbs datasets/wfigs + datasets/frap_fire_perimeters + datasets/geomac_historical + datasets/hms_smoke + datasets/goesr_fdcf + datasets/wrc_housing_density + datasets/landscan_population datasets/caravan_streamflow datasets/floodcastbench_inundation datasets/hydrobench_streamflow diff --git a/docs/source/pyhazards_models.rst b/docs/source/pyhazards_models.rst index fd46d953..ad8d7eb1 100644 --- a/docs/source/pyhazards_models.rst +++ b/docs/source/pyhazards_models.rst @@ -29,7 +29,7 @@ At a Glance .. container:: catalog-stat-value - 24 + 66 .. container:: catalog-stat-note @@ -51,7 +51,7 @@ At a Glance .. container:: catalog-stat-value - 27 + 61 .. container:: catalog-stat-note @@ -72,7 +72,7 @@ pages and compatible benchmark coverage. .. container:: catalog-section-note - Wildfire models cover danger forecasting, weekly activity forecasting, and spread prediction under the shared wildfire benchmark family. + Wildfire models now cover forecasting, spread prediction, operational detection, foundation-model transfer, and prompted multimodal reasoning under the shared wildfire benchmark family. .. rubric:: Implemented Models @@ -184,35 +184,39 @@ pages and compatible benchmark coverage. **Paper:** `ForeFire: A Modular, Scriptable C++ Simulation Engine and Library for Wildland-Fire Spread `_ | **Repo:** `Repository `__ - .. grid-item-card:: Wildfire Forecasting + .. grid-item-card:: WildfireSpreadTS :class-card: catalog-entry-card .. container:: catalog-entry-summary - A sequence forecasting baseline for next-window wildfire activity across weekly count features. + A temporal convolution wildfire spread baseline over short raster history windows. .. 
container:: catalog-chip-row - :bdg-primary:`Wildfire` :bdg-secondary:`Forecasting` :bdg-success:`Implemented` + :bdg-primary:`Wildfire` :bdg-secondary:`Spread` :bdg-success:`Implemented` .. container:: catalog-meta-row - **Details:** :doc:`Wildfire Forecasting ` + **Details:** :doc:`WildfireSpreadTS ` .. container:: catalog-meta-row **Benchmark Family:** :doc:`Wildfire Benchmark ` + .. container:: catalog-meta-row + + **Benchmark Ecosystems:** :doc:`WildfireSpreadTS ` + .. container:: catalog-link-row - **Paper:** `Wildfire Danger Prediction and Understanding with Deep Learning `_ | **Repo:** `Repository `__ + **Paper:** `WildfireSpreadTS: A Dataset of Multi-Modal Time Series for Wildfire Spread Prediction `_ | **Repo:** `Repository `__ - .. grid-item-card:: WildfireSpreadTS + .. grid-item-card:: WRF-SFIRE Adapter :class-card: catalog-entry-card .. container:: catalog-entry-summary - A temporal convolution wildfire spread baseline over short raster history windows. + A lightweight raster wildfire spread adapter inspired by WRF-SFIRE style transport. .. container:: catalog-chip-row @@ -220,7 +224,7 @@ pages and compatible benchmark coverage. .. container:: catalog-meta-row - **Details:** :doc:`WildfireSpreadTS ` + **Details:** :doc:`WRF-SFIRE Adapter ` .. container:: catalog-meta-row @@ -232,14 +236,14 @@ pages and compatible benchmark coverage. .. container:: catalog-link-row - **Paper:** `WildfireSpreadTS: A Dataset of Multi-Modal Time Series for Wildfire Spread Prediction `_ | **Repo:** `Repository `__ + **Paper:** `Coupled atmosphere-wildland fire modeling with WRF 3.3 and SFIRE 2011 `_ | **Repo:** `Repository `__ - .. grid-item-card:: WRF-SFIRE Adapter + .. grid-item-card:: CNN-ASPP :class-card: catalog-entry-card .. container:: catalog-entry-summary - A lightweight raster wildfire spread adapter inspired by WRF-SFIRE style transport. + An explainable CNN segmentation model with an ASPP mechanism for next-day wildfire spread prediction. .. 
container:: catalog-chip-row @@ -247,7 +251,7 @@ pages and compatible benchmark coverage. .. container:: catalog-meta-row - **Details:** :doc:`WRF-SFIRE Adapter ` + **Details:** :doc:`CNN-ASPP ` .. container:: catalog-meta-row @@ -259,14 +263,14 @@ pages and compatible benchmark coverage. .. container:: catalog-link-row - **Paper:** `Coupled atmosphere-wildland fire modeling with WRF 3.3 and SFIRE 2011 `_ | **Repo:** `Repository `__ + **Paper:** `Application of Explainable Artificial Intelligence in Predicting Wildfire Spread `_ - .. grid-item-card:: CNN-ASPP + .. grid-item-card:: FirePred :class-card: catalog-entry-card .. container:: catalog-entry-summary - An explainable CNN segmentation model with an ASPP mechanism for next-day wildfire spread prediction. + A hybrid multi-temporal CNN wildfire spread model over short satellite-history windows. .. container:: catalog-chip-row @@ -274,19 +278,798 @@ pages and compatible benchmark coverage. .. container:: catalog-meta-row - **Details:** :doc:`CNN-ASPP ` + **Details:** :doc:`FirePred ` .. container:: catalog-meta-row **Benchmark Family:** :doc:`Wildfire Benchmark ` + .. container:: catalog-link-row + + **Repo:** `Repository `__ + + .. grid-item-card:: FireMM-IR + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A benchmark-facing multi-modal large-model port for infrared-enhanced wildfire scene understanding. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`LLM / MLLM` :bdg-success:`Implemented` + .. container:: catalog-meta-row - **Benchmark Ecosystems:** :doc:`WildfireSpreadTS ` + **Details:** :doc:`FireMM-IR ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` .. container:: catalog-link-row - **Paper:** `Application of Explainable Artificial Intelligence in Predicting Wildfire Spread `_ + **Paper:** `FireMM-IR `_ + + .. grid-item-card:: MODIS Active Fire C6.1 + :class-card: catalog-entry-card + + .. 
container:: catalog-entry-summary + + An operational-detection baseline inspired by NASA's MODIS Collection 6.1 active-fire algorithm. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Operational Detection` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`MODIS Active Fire C6.1 ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Paper:** `Giglio et al. (2016) `_ + + .. grid-item-card:: Prithvi-EO-2.0-TL + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A transfer-learning earth-observation foundation-model port for dense wildfire prediction. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Foundation Model` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`Prithvi-EO-2.0-TL ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Model Card:** `IBM-NASA Prithvi-EO-2.0-300M-TL `_ + + .. grid-item-card:: Prithvi BurnScars + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A benchmark-facing burn-scar segmentation downstream model derived from the Prithvi EO family. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Foundation Model` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`Prithvi BurnScars ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Model Card:** `Prithvi-EO-2.0-300M-BurnScars `_ + + .. grid-item-card:: Prithvi-WxC + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A weather-climate foundation-model port adapted for dense wildfire-risk prediction. + + .. 
container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Foundation Model` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`Prithvi-WxC ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Paper:** `Prithvi WxC `_ + + .. grid-item-card:: Gemini 2.5 Pro Wildfire Prompted + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A prompt-conditioned wildfire VLM baseline inspired by Gemini 2.5 Pro. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`LLM / MLLM` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`Gemini 2.5 Pro Wildfire Prompted ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** `Gemini models documentation `_ + + .. grid-item-card:: InternVL3 Wildfire Prompted + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A prompt-conditioned wildfire VLM baseline inspired by InternVL3. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`LLM / MLLM` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`InternVL3 Wildfire Prompted ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Repo:** `Repository `__ + + .. grid-item-card:: Llama 4 Wildfire Prompted + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A prompt-conditioned multimodal wildfire baseline inspired by Meta Llama 4. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`LLM / MLLM` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`Llama 4 Wildfire Prompted ` + + .. 
container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** `Meta Llama `_ + + .. grid-item-card:: Qwen2.5-VL Wildfire Prompted + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A prompt-conditioned wildfire VLM baseline inspired by Qwen2.5-VL. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`LLM / MLLM` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`Qwen2.5-VL Wildfire Prompted ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Paper:** `Qwen2.5-VL Technical Report `_ + + .. grid-item-card:: TS-SatFire + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A spatio-temporal satellite wildfire benchmark model over multi-temporal raster sequences. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Spread` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`TS-SatFire ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Paper:** `TS-SatFire `_ + + .. grid-item-card:: VIIRS 375 m Active Fire + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + An operational-detection baseline inspired by NASA's VIIRS 375 m active-fire algorithm. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Operational Detection` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`VIIRS 375 m Active Fire ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Paper:** `Schroeder et al. (2014) `_ + + .. grid-item-card:: WildfireGPT + :class-card: catalog-entry-card + + .. 
container:: catalog-entry-summary + + A benchmark-facing multi-agent wildfire reasoning model inspired by WildfireGPT. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`LLM / MLLM` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`WildfireGPT ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Repo:** `Repository `__ + + +.. grid-item-card:: Logistic Regression + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A classical binary wildfire occurrence baseline over tabular features. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Classification` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`Logistic Regression ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: Random Forest + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A random-forest wildfire occurrence baseline over tabular features. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Classification` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`Random Forest ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: XGBoost + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A boosted-tree wildfire occurrence baseline using a binary logistic objective. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Classification` :bdg-success:`Implemented` + + .. 
container:: catalog-meta-row + + **Details:** :doc:`XGBoost ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: LightGBM + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A boosted-tree wildfire occurrence baseline using LightGBM binary classification. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Classification` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`LightGBM ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: U-Net + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A compact dense-prediction wildfire baseline built on a U-Net style encoder-decoder. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Segmentation` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`U-Net ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: ResNet-18 U-Net + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A residual encoder-decoder wildfire segmentation baseline. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Segmentation` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`ResNet-18 U-Net ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. 
grid-item-card:: Attention U-Net + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + An attention-gated U-Net wildfire segmentation baseline. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Segmentation` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`Attention U-Net ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: DeepLabv3+ + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A DeepLab-style wildfire segmentation baseline with ASPP-like context aggregation. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Segmentation` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`DeepLabv3+ ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: ConvLSTM + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A recurrent spatio-temporal wildfire prediction baseline over raster sequences. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Spatiotemporal` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`ConvLSTM ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: MAU + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A compact memory-augmented recurrent wildfire prediction baseline. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Spatiotemporal` :bdg-success:`Implemented` + + .. 
container:: catalog-meta-row + + **Details:** :doc:`MAU ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: PredRNN-v2 + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A predictive recurrent wildfire raster baseline. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Spatiotemporal` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`PredRNN-v2 ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: Rainformer + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A transformer-style spatio-temporal wildfire raster baseline. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Spatiotemporal` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`Rainformer ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: Earthformer + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A compact Earthformer-style wildfire forecasting baseline. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Spatiotemporal` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`Earthformer ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: SwinLSTM + :class-card: catalog-entry-card + + .. 
container:: catalog-entry-summary + + A windowed-attention recurrent wildfire raster baseline. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Spatiotemporal` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`SwinLSTM ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: EarthFarseer + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A compact EarthFarseer-style wildfire forecasting baseline. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Spatiotemporal` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`EarthFarseer ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: ConvGRU / TrajGRU + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A recurrent wildfire baseline mixing ConvGRU and TrajGRU-style dynamics. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Spatiotemporal` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`ConvGRU / TrajGRU ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: TCN + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A temporal convolution wildfire baseline over short raster histories. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Spatiotemporal` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`TCN ` + + .. 
container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: UTAE + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A temporal attention encoder wildfire baseline. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Spatiotemporal` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`UTAE ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: SegFormer + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A transformer-based dense wildfire prediction baseline. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Transformer` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`SegFormer ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: Swin-UNet + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A Swin-style encoder-decoder wildfire segmentation baseline. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Transformer` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`Swin-UNet ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: ViT Segmenter + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + A ViT-style dense wildfire segmentation baseline. + + .. 
container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Transformer` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`ViT Segmenter ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. + +.. grid-item-card:: Deep Ensemble + :class-card: catalog-entry-card + + .. container:: catalog-entry-summary + + An ensemble wildfire segmentation baseline that averages multiple member networks. + + .. container:: catalog-chip-row + + :bdg-primary:`Wildfire` :bdg-secondary:`Uncertainty` :bdg-success:`Implemented` + + .. container:: catalog-meta-row + + **Details:** :doc:`Deep Ensemble ` + + .. container:: catalog-meta-row + + **Benchmark Family:** :doc:`Wildfire Benchmark ` + + .. container:: catalog-link-row + + **Source:** Promoted Track-O baseline implementation. .. tab-item:: Earthquake @@ -927,6 +1710,40 @@ before selecting a model for evaluation. 
modules/models_eqnet modules/models_eqtransformer modules/models_firecastnet + modules/models_logistic_regression + modules/models_random_forest + modules/models_xgboost + modules/models_lightgbm + modules/models_unet + modules/models_resnet18_unet + modules/models_attention_unet + modules/models_deeplabv3p + modules/models_convlstm + modules/models_mau + modules/models_predrnn_v2 + modules/models_rainformer + modules/models_earthformer + modules/models_swinlstm + modules/models_earthfarseer + modules/models_convgru_trajgru + modules/models_tcn + modules/models_utae + modules/models_segformer + modules/models_swin_unet + modules/models_vit_segmenter + modules/models_deep_ensemble + modules/models_firemm_ir + modules/models_firepred + modules/models_gemini_25_pro_wildfire_prompted + modules/models_internvl3_wildfire_prompted + modules/models_llama4_wildfire_prompted + modules/models_modis_active_fire_c61 + modules/models_prithvi_burnscars + modules/models_prithvi_eo_2_tl + modules/models_prithvi_wxc + modules/models_qwen25_vl_wildfire_prompted + modules/models_ts_satfire + modules/models_viirs_375m_active_fire modules/models_floodcast modules/models_forefire modules/models_fourcastnet_tc @@ -946,7 +1763,7 @@ before selecting a model for evaluation. 
modules/models_urbanfloodcast modules/models_wavecastnet modules/models_wildfire_aspp - modules/models_wildfire_forecasting + modules/models_wildfiregpt modules/models_wildfire_fpa modules/models_wildfirespreadts modules/models_wrf_sfire diff --git a/pyhazards/benchmark_cards/wildfire_benchmark.yaml b/pyhazards/benchmark_cards/wildfire_benchmark.yaml index e905aaf2..f965b5fa 100644 --- a/pyhazards/benchmark_cards/wildfire_benchmark.yaml +++ b/pyhazards/benchmark_cards/wildfire_benchmark.yaml @@ -37,7 +37,6 @@ smoke_configs: - pyhazards/configs/wildfire/firecastnet_smoke.yaml linked_models: - wildfire_fpa - - wildfire_forecasting - asufm - wildfire_aspp - wildfirespreadts @@ -46,3 +45,5 @@ linked_models: - firecastnet notes: - "WildfireSpreadTS is the public Appendix-A benchmark ecosystem surfaced on this page." + - "Run artifacts are organized under runs/wildfire_benchmark/{smoke,real,archive}." + - "The canonical experiment-setting schema lives in pyhazards/benchmarks/wildfire_benchmark/experiment_settings.py." 
diff --git a/pyhazards/benchmarks/__init__.py b/pyhazards/benchmarks/__init__.py index 2177b616..7a9b5fcc 100644 --- a/pyhazards/benchmarks/__init__.py +++ b/pyhazards/benchmarks/__init__.py @@ -4,6 +4,36 @@ from .schemas import BenchmarkResult, BenchmarkRunSummary from .earthquake import EarthquakeBenchmark from .wildfire import WildfireBenchmark +from .wildfire_benchmark import ( + AdapterRunOutput, + BenchmarkSection, + CacheBuildSummary, + align_static_fuel_to_cache, + REPRESENTATIVE_MODELS, + EvaluationProtocolSection, + ModelSection, + RunSection, + RunPaths, + WILDFIRE_BENCHMARK_CONFIG_ROOT, + WILDFIRE_RUNS_ROOT, + WildfireExperimentSetting, + WildfireSmokeAdapter, + SyntheticWildfireModelAdapter, + build_cache, + run_real_baselines, + build_default_experiment_setting, + write_experiment_setting, + prepare_run_paths, + build_experiment_setting_from_run_output, + build_model_template, + load_contract, + load_model_catalog, + parse_seed_list, + select_models, + run_smoke_batch, + create_adapter, + resolve_local_model_name, +) from .flood import FloodBenchmark from .tc import TropicalCycloneBenchmark @@ -15,6 +45,34 @@ "BenchmarkRunSummary", "TropicalCycloneBenchmark", "WildfireBenchmark", + "AdapterRunOutput", + "BenchmarkSection", + "CacheBuildSummary", + "align_static_fuel_to_cache", + "REPRESENTATIVE_MODELS", + "EvaluationProtocolSection", + "ModelSection", + "RunSection", + "RunPaths", + "WILDFIRE_BENCHMARK_CONFIG_ROOT", + "WILDFIRE_RUNS_ROOT", + "WildfireExperimentSetting", + "WildfireSmokeAdapter", + "SyntheticWildfireModelAdapter", + "build_cache", + "run_real_baselines", + "build_default_experiment_setting", + "write_experiment_setting", + "prepare_run_paths", + "build_experiment_setting_from_run_output", + "build_model_template", + "load_contract", + "load_model_catalog", + "parse_seed_list", + "select_models", + "run_smoke_batch", + "create_adapter", + "resolve_local_model_name", "available_benchmarks", "build_benchmark", "get_benchmark", diff --git 
a/pyhazards/benchmarks/wildfire_benchmark/REAL_DATA_2024_PLAN.md b/pyhazards/benchmarks/wildfire_benchmark/REAL_DATA_2024_PLAN.md new file mode 100644 index 00000000..3b175883 --- /dev/null +++ b/pyhazards/benchmarks/wildfire_benchmark/REAL_DATA_2024_PLAN.md @@ -0,0 +1,286 @@ +# Wildfire Benchmark Real-Data Plan (2024 v1) + +## Goal + +Run the first real-data wildfire benchmark in `my-copy` using the 2024 data pack that is already available locally. + +This plan treats 2024 as the first stable benchmark year because: +- the local 2024 label and weather coverage is already present; +- the benchmark contract already uses 2024 splits; +- we can start with a fair, single-year method comparison before moving to 2025 generalization. + +## Benchmark Year + +- Benchmark year: `2024` +- Primary task: `Track-O` +- Task definition: predict grid-level or aggregated wildfire occurrence probability `P(y=1 | x)` + +## Real-Data 2024 Dataset Pack + +### Core inputs + +1. **Fire labels / fire history** +- Path: `/home/runyang/ryang/firms/combine` +- Format: daily CSV files such as `2024-01-10.csv` +- Role: primary occurrence label source and lagged fire-history source + +2. **Dynamic weather / land-surface forcing** +- Path: `/home/runyang/output2024` +- Format: Prithvi-WxC predicted NetCDF files such as `pred_20240101_18.nc` +- Available channels observed in sample files: + - `T2M`, `QV2M`, `TQV`, `U10M`, `V10M`, `GWETROOT`, `TS`, `LAI`, `EFLUX`, `HFLUX`, `SWGNT`, `SWTNT`, `LWGAB`, `LWGEM` +- Role: main dynamic feature source + +3. **Static fuels / vegetation** +- Path: `/home/runyang/ryang/landfire_fbfm40` +- Role: static fuel and vegetation background + +### Recommended v1 optional inputs + +4. **Perimeters** +- Path: `/home/runyang/ryang/WFIGS_Perimeters/history_2024` +- Role: event extent, perimeter-derived context, perimeter proximity features + +5. 
**Human activity proxy** +- Candidate paths: + - `/home/runyang/ryang/WRC_Housing_Density` + - `/home/runyang/ryang/LandScan_Global_2024` +- Role: ignition proxy and exposure context + +## Minimum Real-Data Feature Packs by Model Family + +### Classical / Trees +Use aggregated tabular features. + +Required: +- FIRMS labels / lagged fire counts +- aggregated weather features from `output2024` +- LANDFIRE static fuel features + +Recommended: +- WFIGS perimeter proximity +- housing / population + +### Deep Learning +Use raster or raster-sequence tensors. + +Required: +- FIRMS rasterized labels +- `output2024` weather tensors +- LANDFIRE static channels + +Recommended: +- fire-history channels from FIRMS +- WFIGS perimeter channels or masks + +### Satellite Remote Sensing +Use raster tensors with wildfire-specific spatial observations. + +Required for v1: +- FIRMS labels / fire-history +- `output2024` +- LANDFIRE + +Recommended for later v2: +- GOES FDCF +- HMS Smoke + +### Physics / Simulators +Required: +- weather +- fuels +- perimeter or ignition initialization + +This group should not block the first real-data benchmark if data conversion takes longer. + +### Foundation Models +- `prithvi_wxc`: prioritize weather sequence tensors from `output2024` +- `prithvi_eo_2_tl`, `prithvi_burnscars`: use raster sequences plus static channels and fire-history context + +### LLM / MLLM +Do not block v1 on raw NetCDF ingestion. +Use summarized products, rendered maps, metadata, and benchmark-derived inputs after the core benchmark is stable. + +## Data Processing Strategy + +### Step 1: Build a canonical benchmark grid and date index +- Use `output2024/pred_20240101_18.nc` as a canonical weather grid reference. +- Create a canonical daily date list from `2024-01-01` through `2024-12-31`. +- Align all dynamic inputs to that grid and daily calendar. + +### Step 2: Build labels +- Read FIRMS daily CSV files from `/home/runyang/ryang/firms/combine`. 
+- Rasterize or aggregate them onto the benchmark grid. +- Create: + - `y_t`: binary occurrence label for day `t` + - optional lagged fire-history channels from prior days + +### Step 3: Build dynamic weather tensors +- Read `output2024/pred_*.nc` +- Daily aggregate if multiple files per day are used +- Select the 14 current channels as the default dynamic pack + +### Step 4: Build static tensors +- Reproject or sample LANDFIRE fuels to the benchmark grid +- Add optional human-activity layers if needed + +### Step 5: Materialize cached benchmark-ready arrays +Recommended cache layout: + +```text +/home/runyang/my-copy/data_cache/wildfire_2024_v1/ + dates.txt + labels/ + 2024-01-01.npy + met/ + 2024-01-01.npy + static/ + fuel.npy + housing.npy + population.npy + metadata/ + grid.json + vars.json +``` + +## Train / Val / Test Protocol + +Use the current benchmark contract split: +- Train: `2024-01-01` to `2024-09-30` +- Val: `2024-10-01` to `2024-10-31` +- Test: `2024-11-01` to `2024-12-31` + +Rules: +- fit all normalization statistics on train only; +- no future covariates relative to the prediction target; +- store fixed split files for reproducibility. + +## Training Recommendations + +### Phase A: real-data dry run +Use one seed first. +- Seed: `42` +- Purpose: verify data loading, training loop, output schema, and metric computation + +### Phase B: final benchmark runs +Use multi-seed reporting. 
+- Seeds: `42, 52, 62, 72, 82` +- Report: `mean ± std` + +### Classical / Trees +- `logistic_regression`: native binary objective +- `random_forest`: `predict_proba` +- `xgboost`: binary objective, several hundred rounds allowed +- `lightgbm`: binary objective, several hundred rounds allowed + +### Deep models +- task: binary occurrence probability +- output: one logit per grid cell / tile / target unit +- loss: `BCEWithLogitsLoss` +- recommended initial schedule: + - `max_epochs = 120` or higher + - early stopping monitor: `val_auprc` + - `patience = 20 to 30` + - `min_delta = 1e-4` +- current smoke settings are not sufficient for convergence claims + +## GPU Policy + +Real-data deep-model training should use GPU, not CPU. + +Record in `experiment_setting.json`: +- device +- gpu id +- gpu name +- total memory if available + +Recommended policy: +- classical models may remain on CPU unless GPU versions are explicitly used +- deep models should default to `cuda:0` + +## Output Layout + +All new real-data benchmark artifacts should be written under: + +```text +/home/runyang/my-copy/runs/wildfire_benchmark/real/ +``` + +Recommended run layout: + +```text +runs/wildfire_benchmark/real/track_o_2024_real_v1/ + benchmark_contract_snapshot.json + benchmark_summary.json + experiment_templates.json + <model_name>/ + model_template.json + model_summary.json + seed_42/ + experiment_setting.json + history.csv + loss_curve.png + metrics.json +``` + +## Required Per-Seed Outputs + +For every model and seed: +- `experiment_setting.json` +- `history.csv` +- `loss_curve.png` +- `metrics.json` + +### history.csv should include +At minimum: +- step column (`epoch`, `round`, `iteration`, or `tree_count`) +- `train_loss` +- `val_loss` +- optional learning-rate column when applicable + +### loss_curve.png should show +- train loss vs step +- validation loss vs step +- clear title with model name and train unit + +## Evaluation Protocol + +### Primary metrics +- `AUPRC` + +### Secondary metrics +- 
`AUROC` + +### Reliability metrics +- `Brier` +- `NLL` +- `ECE` + +### Temporal consistency metrics +- `mean_day_to_day_change` +- `normalized_consistency_score` + +### Reporting rules +- report mean and std across seeds for final benchmark numbers +- include train/val loss curves +- log best step +- log converged step + +## Recommended Execution Order + +1. Build cache from FIRMS + `output2024` + LANDFIRE +2. Run `seed=42` dry run on 4 representative models: + - `logistic_regression` + - `xgboost` + - `unet` + - `convlstm` +3. Validate output artifacts and metric computation +4. Expand to the rest of the main benchmark roster +5. Add remote-sensing / foundation / simulator tracks afterwards + +## Immediate Implementation Notes + +- The current `track_o_2024_v1.json` still points to `/home/runyang/ryang/firms_download/combine`. +- The locally verified combined FIRMS label directory is `/home/runyang/ryang/firms/combine`. +- The first real-data contract should use the verified local path. 
+ diff --git a/pyhazards/benchmarks/wildfire_benchmark/__init__.py b/pyhazards/benchmarks/wildfire_benchmark/__init__.py new file mode 100644 index 00000000..91cddac8 --- /dev/null +++ b/pyhazards/benchmarks/wildfire_benchmark/__init__.py @@ -0,0 +1,47 @@ +from .experiment_settings import ( + BenchmarkSection, + EvaluationProtocolSection, + ModelSection, + RunSection, + WildfireExperimentSetting, + build_default_experiment_setting, + write_experiment_setting, +) +from .layout import RunPaths, WILDFIRE_RUNS_ROOT, prepare_run_paths +from .artifacts import AdapterRunOutput, build_experiment_setting_from_run_output, build_model_template +from .catalog import WILDFIRE_BENCHMARK_CONFIG_ROOT, load_contract, load_model_catalog, parse_seed_list, select_models +from .runner import run_smoke_batch +from .cache_builder import CacheBuildSummary, align_static_fuel_to_cache, build_cache +from .real_runner import REPRESENTATIVE_MODELS, run_real_baselines +from .adapters import WildfireSmokeAdapter, SyntheticWildfireModelAdapter, create_adapter, resolve_local_model_name + +__all__ = [ + "AdapterRunOutput", + "BenchmarkSection", + "CacheBuildSummary", + "align_static_fuel_to_cache", + "REPRESENTATIVE_MODELS", + "EvaluationProtocolSection", + "ModelSection", + "RunSection", + "RunPaths", + "WILDFIRE_BENCHMARK_CONFIG_ROOT", + "WILDFIRE_RUNS_ROOT", + "WildfireExperimentSetting", + "WildfireSmokeAdapter", + "SyntheticWildfireModelAdapter", + "build_cache", + "run_real_baselines", + "build_default_experiment_setting", + "write_experiment_setting", + "prepare_run_paths", + "build_experiment_setting_from_run_output", + "build_model_template", + "load_contract", + "load_model_catalog", + "parse_seed_list", + "select_models", + "run_smoke_batch", + "create_adapter", + "resolve_local_model_name", +] diff --git a/pyhazards/benchmarks/wildfire_benchmark/adapters/__init__.py b/pyhazards/benchmarks/wildfire_benchmark/adapters/__init__.py new file mode 100644 index 00000000..2229d760 --- /dev/null 
+++ b/pyhazards/benchmarks/wildfire_benchmark/adapters/__init__.py @@ -0,0 +1,10 @@ +from .base import WildfireSmokeAdapter +from .registry import create_adapter +from .synthetic import SyntheticWildfireModelAdapter, resolve_local_model_name + +__all__ = [ + "WildfireSmokeAdapter", + "SyntheticWildfireModelAdapter", + "create_adapter", + "resolve_local_model_name", +] diff --git a/pyhazards/benchmarks/wildfire_benchmark/adapters/base.py b/pyhazards/benchmarks/wildfire_benchmark/adapters/base.py new file mode 100644 index 00000000..511c657d --- /dev/null +++ b/pyhazards/benchmarks/wildfire_benchmark/adapters/base.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import hashlib +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +from ..artifacts import AdapterRunOutput, STEP_LABEL, build_experiment_setting_from_run_output + + +class WildfireSmokeAdapter(ABC): + """Minimal adapter contract for my-copy wildfire benchmark smoke runs.""" + + def __init__(self, model_spec: Dict[str, Any], contract: Dict[str, Any], step_limits: Dict[str, int]): + self.model_spec = model_spec + self.contract = contract + self.step_limits = step_limits + + @property + def model_name(self) -> str: + return str(self.model_spec["name"]) + + @property + def train_unit(self) -> str: + return str(self.model_spec["train_unit"]) + + @property + def step_name(self) -> str: + return STEP_LABEL[self.train_unit] + + def resolve_num_steps(self) -> int: + defaults = self.model_spec.get("defaults", {}) + if self.train_unit == "epoch": + return max(5, min(int(defaults.get("max_epochs", self.step_limits["epoch"])), self.step_limits["epoch"])) + if self.train_unit == "round": + return max(10, min(int(defaults.get("num_boost_round", self.step_limits["round"])), self.step_limits["round"])) + if self.train_unit == "iteration": + return max(10, min(int(defaults.get("max_iter", self.step_limits["iteration"])), self.step_limits["iteration"])) + if self.train_unit == "tree": + return 
max(10, min(int(defaults.get("n_estimators", self.step_limits["tree"])), self.step_limits["tree"])) + raise ValueError(f"Unsupported train_unit={self.train_unit}") + + @abstractmethod + def run(self, seed: int) -> AdapterRunOutput: + """Run one smoke seed and return standardized benchmark artifacts.""" + + def build_experiment_setting(self, seed: int, run_output: AdapterRunOutput) -> Dict[str, Any]: + return build_experiment_setting_from_run_output( + contract=self.contract, + model_spec=self.model_spec, + seed=int(seed), + run_output=run_output, + ) + + +def stable_seed_offset(model_name: str) -> int: + digest = hashlib.sha256(model_name.encode("utf-8")).hexdigest()[:8] + return int(digest, 16) + + +def moving_average(values: List[float], window: int) -> List[float]: + if window <= 1: + return values[:] + out: List[float] = [] + for idx in range(len(values)): + start = max(0, idx - window + 1) + chunk = values[start : idx + 1] + out.append(float(sum(chunk) / len(chunk))) + return out + + +def find_converged_step(history: List[Dict[str, float]], train_unit: str, smooth_window: int, patience: int, min_improvement: float) -> int: + step_key = STEP_LABEL[train_unit] + val_loss = [float(row["val_loss"]) for row in history] + smoothed = moving_average(val_loss, smooth_window) + + stable = 0 + for idx in range(1, len(smoothed)): + improvement = smoothed[idx - 1] - smoothed[idx] + if improvement < min_improvement: + stable += 1 + else: + stable = 0 + if stable >= patience: + return int(history[idx][step_key]) + return int(history[-1][step_key]) + + +def normalized_consistency_score(mean_day_to_day_change: float) -> float: + return max(0.0, min(1.0, 1.0 - float(mean_day_to_day_change))) diff --git a/pyhazards/benchmarks/wildfire_benchmark/adapters/registry.py b/pyhazards/benchmarks/wildfire_benchmark/adapters/registry.py new file mode 100644 index 00000000..f56a1eea --- /dev/null +++ b/pyhazards/benchmarks/wildfire_benchmark/adapters/registry.py @@ -0,0 +1,14 @@ +from 
__future__ import annotations + +from typing import Any, Dict + +from .base import WildfireSmokeAdapter +from .synthetic import SyntheticWildfireModelAdapter + + +SMOKE_ADAPTERS = {} + + +def create_adapter(model_spec: Dict[str, Any], contract: Dict[str, Any], step_limits: Dict[str, int]) -> WildfireSmokeAdapter: + adapter_cls = SMOKE_ADAPTERS.get(str(model_spec["name"]), SyntheticWildfireModelAdapter) + return adapter_cls(model_spec=model_spec, contract=contract, step_limits=step_limits) diff --git a/pyhazards/benchmarks/wildfire_benchmark/adapters/synthetic.py b/pyhazards/benchmarks/wildfire_benchmark/adapters/synthetic.py new file mode 100644 index 00000000..9dedf286 --- /dev/null +++ b/pyhazards/benchmarks/wildfire_benchmark/adapters/synthetic.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +import importlib +import math +from pathlib import Path +from typing import Dict, List + +import numpy as np + +from pyhazards.models import available_models + +from ..artifacts import AdapterRunOutput +from .base import ( + WildfireSmokeAdapter, + find_converged_step, + normalized_consistency_score, + stable_seed_offset, +) + +MODEL_NAME_ALIASES = { + "wrf_sfire_adapter": "wrf_sfire", + "forefire_adapter": "forefire", +} + + +def resolve_local_model_name(model_name: str) -> str: + return MODEL_NAME_ALIASES.get(model_name, model_name) + + +def summarize_metrics(best_val_loss: float, seed: int, model_name: str) -> Dict[str, float]: + seed_offset = stable_seed_offset(model_name) // 17 + rng = np.random.default_rng(seed + seed_offset) + + quality = float(np.clip(1.0 / (1.0 + best_val_loss), 0.0, 1.0)) + auprc = float(np.clip(0.03 + 0.5 * quality + rng.normal(0.0, 0.015), 0.0, 1.0)) + auroc = float(np.clip(0.55 + 0.42 * quality + rng.normal(0.0, 0.01), 0.0, 1.0)) + brier = float(np.clip(0.42 - 0.28 * quality + rng.normal(0.0, 0.01), 0.0, 1.0)) + nll = float(np.clip(1.1 - 0.7 * quality + rng.normal(0.0, 0.03), 0.01, 5.0)) + ece = float(np.clip(0.22 - 0.14 * quality + 
rng.normal(0.0, 0.008), 0.0, 1.0)) + temporal_delta = float(np.clip(0.25 - 0.12 * quality + rng.normal(0.0, 0.01), 0.0, 1.0)) + + return { + "auprc": auprc, + "auroc": auroc, + "brier": brier, + "nll": nll, + "ece": ece, + "mean_day_to_day_change": temporal_delta, + "normalized_consistency_score": normalized_consistency_score(temporal_delta), + } + + +class SyntheticWildfireModelAdapter(WildfireSmokeAdapter): + """Smoke adapter for migrated my-copy wildfire benchmark models.""" + + def _simulate_history(self, seed: int, num_steps: int) -> List[Dict[str, float]]: + step_key = self.step_name + defaults = self.model_spec.get("defaults", {}) + + seed_offset = stable_seed_offset(self.model_name) + rng = np.random.default_rng(seed + seed_offset) + + base_lr = float(defaults.get("lr", defaults.get("learning_rate", defaults.get("eta", 1e-3)))) + start_loss = float(rng.uniform(0.8, 1.6)) + floor_loss = float(rng.uniform(0.08, 0.25)) + speed = float(rng.uniform(0.02, 0.08)) + + history: List[Dict[str, float]] = [] + for step in range(1, num_steps + 1): + decay = floor_loss + (start_loss - floor_loss) * math.exp(-speed * step) + noise = float(rng.normal(0.0, 0.01)) + train_loss = max(0.01, decay + noise) + + gap = float(rng.uniform(0.02, 0.09)) + val_noise = float(rng.normal(0.0, 0.008)) + val_loss = max(0.01, train_loss + gap + val_noise) + + cosine = 0.5 * (1.0 + math.cos(math.pi * (step - 1) / max(1, num_steps - 1))) + learning_rate = base_lr * cosine + + history.append( + { + step_key: float(step), + "train_loss": float(train_loss), + "val_loss": float(val_loss), + "learning_rate": float(learning_rate), + } + ) + return history + + def _resolve_model_metadata(self) -> Dict[str, object]: + local_name = resolve_local_model_name(self.model_name) + registered = local_name in set(available_models()) + source_path = None + try: + module = importlib.import_module(f"pyhazards.models.{local_name}") + source_path = str(Path(module.__file__).resolve()) if getattr(module, "__file__", 
None) else None + except Exception: + source_path = None + return { + "canonical_model_name": local_name, + "registered_in_my_copy": registered, + "model_source": source_path, + } + + def run(self, seed: int) -> AdapterRunOutput: + num_steps = self.resolve_num_steps() + history = self._simulate_history(seed=seed, num_steps=num_steps) + + val_loss = [float(item["val_loss"]) for item in history] + best_idx = int(np.argmin(np.asarray(val_loss))) + best_step = int(history[best_idx][self.step_name]) + + conv_cfg = self.contract["shared_training"]["convergence_rule"] + converged_step = find_converged_step( + history=history, + train_unit=self.train_unit, + smooth_window=int(conv_cfg["smoothing_window"]), + patience=int(conv_cfg["patience"]), + min_improvement=float(conv_cfg["min_improvement"]), + ) + + metrics = summarize_metrics(best_val_loss=float(val_loss[best_idx]), seed=seed, model_name=self.model_name) + model_meta = self._resolve_model_metadata() + + return AdapterRunOutput( + history=history, + metrics=metrics, + best_step=best_step, + converged_step=converged_step, + train_unit=self.train_unit, + notes={ + "adapter_kind": "synthetic_my_copy_benchmark", + "status": "smoke_only", + "message": "Synthetic smoke run executed inside my-copy wildfire benchmark skeleton.", + **model_meta, + }, + ) diff --git a/pyhazards/benchmarks/wildfire_benchmark/artifacts.py b/pyhazards/benchmarks/wildfire_benchmark/artifacts.py new file mode 100644 index 00000000..6535b3db --- /dev/null +++ b/pyhazards/benchmarks/wildfire_benchmark/artifacts.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import csv +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Mapping + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np + +from .experiment_settings import build_default_experiment_setting + +STEP_LABEL = { + "epoch": "epoch", + "round": "round", + "iteration": "iteration", + 
"tree": "tree_count", +} + + +@dataclass +class AdapterRunOutput: + history: List[Dict[str, float]] + metrics: Dict[str, float] + best_step: int + converged_step: int + train_unit: str + notes: Dict[str, Any] + + +def mean_std(values: List[float]) -> Dict[str, float]: + arr = np.asarray(values, dtype=float) + return {"mean": float(np.mean(arr)), "std": float(np.std(arr, ddof=0))} + + +def write_json(path: Path, payload: Mapping[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(dict(payload), indent=2), encoding="utf-8") + + +def write_history_csv(path: Path, rows: List[Dict[str, float]]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + if not rows: + path.write_text("", encoding="utf-8") + return + fieldnames = list(rows[0].keys()) + with path.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + +def plot_loss_curve(history: List[Dict[str, float]], train_unit: str, output_png: Path, title: str) -> None: + output_png.parent.mkdir(parents=True, exist_ok=True) + if not history: + return + step_key = STEP_LABEL[train_unit] + x = [int(row[step_key]) for row in history] + y_tr = [float(row["train_loss"]) for row in history] + y_va = [float(row["val_loss"]) for row in history] + + plt.figure(figsize=(8, 5)) + plt.plot(x, y_tr, marker="o", linewidth=1.6, label="train_loss") + plt.plot(x, y_va, marker="s", linewidth=1.4, label="val_loss") + plt.xlabel(step_key) + plt.ylabel("loss") + plt.title(title) + plt.grid(alpha=0.3) + plt.legend() + plt.tight_layout() + plt.savefig(output_png, dpi=150) + plt.close() + + +def build_model_template(contract: Dict[str, Any], model_spec: Dict[str, Any]) -> Dict[str, Any]: + shared = contract["shared_training"] + seed_list = shared.get("seed_list") or shared.get("dry_run_seed_list") or shared.get("final_seed_list") or [42] + return { + "template_version": 
"track_o_model_template_v1", + "benchmark_name": contract["benchmark_name"], + "contract_version": contract["contract_version"], + "task": contract["task"], + "model": { + "name": model_spec["name"], + "display_name": model_spec["display_name"], + "group": model_spec["group"], + "source_tier": model_spec["source_tier"], + "train_unit": model_spec["train_unit"], + "defaults": model_spec.get("defaults", {}), + }, + "reproducibility": { + "seed_list": seed_list, + "must_report_mean_std": contract["shared_training"]["report_requirements"]["report_mean_std_across_seeds"], + }, + "expected_metrics": [ + "auprc", + "auroc", + "brier", + "nll", + "ece", + "mean_day_to_day_change", + "normalized_consistency_score", + ], + "required_fields_for_real_runs": [ + "repo_url", + "repo_commit_or_tag", + "data_version", + "split_version", + "feature_set_version", + "hyperparam_search_budget", + "hardware", + "software_versions", + ], + } + + +def build_experiment_setting_from_run_output( + *, + contract: Dict[str, Any], + model_spec: Dict[str, Any], + seed: int, + run_output: AdapterRunOutput, +) -> Dict[str, Any]: + setting = build_default_experiment_setting( + model_name=str(model_spec["name"]), + display_name=str(model_spec["display_name"]), + group=str(model_spec["group"]), + source_tier=str(model_spec["source_tier"]), + train_unit=str(model_spec["train_unit"]), + defaults=model_spec.get("defaults", {}), + seed=int(seed), + num_steps=len(run_output.history), + best_step=int(run_output.best_step), + converged_step=int(run_output.converged_step), + step_name=STEP_LABEL[str(model_spec["train_unit"])], + mode=str(contract["mode"]), + task=str(contract["task"]), + notes=dict(run_output.notes), + metrics=run_output.metrics, + ) + setting.run.learning_weight = { + "kind": contract["shared_training"]["class_imbalance"]["policy"], + "value": "to_be_computed_from_real_train_split", + "clip_max": contract["shared_training"]["class_imbalance"]["clip_max"], + } + return setting.to_dict() 
diff --git a/pyhazards/benchmarks/wildfire_benchmark/cache_builder.py b/pyhazards/benchmarks/wildfire_benchmark/cache_builder.py new file mode 100644 index 00000000..65ef8b5d --- /dev/null +++ b/pyhazards/benchmarks/wildfire_benchmark/cache_builder.py @@ -0,0 +1,378 @@ +from __future__ import annotations + +import json +import re +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Sequence, Set + +import numpy as np +import pandas as pd +import xarray as xr +import yaml + +_DATE_RE = re.compile(r"pred_(\d{8})_\d{2}\.nc$") + + +@dataclass +class CacheBuildSummary: + cache_root: Path + n_label_days: int + n_met_days: int + n_shared_days: int + weather_vars: List[str] + train_days: int + val_days: int + test_days: int + + def to_dict(self) -> Dict[str, Any]: + return { + "cache_root": str(self.cache_root), + "n_label_days": int(self.n_label_days), + "n_met_days": int(self.n_met_days), + "n_shared_days": int(self.n_shared_days), + "weather_vars": list(self.weather_vars), + "train_days": int(self.train_days), + "val_days": int(self.val_days), + "test_days": int(self.test_days), + } + + +def _read_yaml(path: str | Path) -> Dict[str, Any]: + return yaml.safe_load(Path(path).read_text(encoding="utf-8")) + + +def _extract_pred_date(path: Path) -> str | None: + match = _DATE_RE.search(path.name) + if not match: + return None + stamp = match.group(1) + return f"{stamp[:4]}-{stamp[4:6]}-{stamp[6:8]}" + + +def _load_grid(sample_nc: Path) -> tuple[np.ndarray, np.ndarray]: + ds = xr.open_dataset(sample_nc) + try: + lat = np.asarray(ds["lat"].values, dtype=np.float64) + lon = np.asarray(ds["lon"].values, dtype=np.float64) + finally: + ds.close() + return lat, lon + + +def _select_weather_vars(ds: xr.Dataset, weather_vars: Sequence[str]) -> xr.Dataset: + missing = [name for name in weather_vars if name not in ds.data_vars] + if missing: + raise KeyError(f"Missing weather variables in dataset: {missing}") + 
return ds[list(weather_vars)] + + +def _discover_weather_groups(weather_dir: Path, weather_glob: str) -> Dict[str, List[Path]]: + grouped: Dict[str, List[Path]] = {} + for path in sorted(weather_dir.glob(weather_glob)): + date = _extract_pred_date(path) + if date is None: + continue + grouped.setdefault(date, []).append(path) + return grouped + + +def _discover_label_paths(firms_dir: Path, year: int) -> Dict[str, Path]: + return {path.stem: path for path in sorted(firms_dir.glob(f"{year}-*.csv"))} + + +def _daily_weather_arrays( + weather_groups: Dict[str, List[Path]], + weather_vars: Sequence[str], + *, + allowed_dates: Set[str] | None = None, +) -> Dict[str, np.ndarray]: + out: Dict[str, np.ndarray] = {} + for date in sorted(weather_groups): + if allowed_dates is not None and date not in allowed_dates: + continue + stacks: List[np.ndarray] = [] + for path in weather_groups[date]: + ds = xr.open_dataset(path) + try: + picked = _select_weather_vars(ds, weather_vars) + arr = np.stack([np.asarray(picked[var].values, dtype=np.float32) for var in weather_vars], axis=0) + if arr.ndim == 4 and arr.shape[1] == 1: + arr = arr[:, 0, :, :] + stacks.append(arr) + finally: + ds.close() + if stacks: + out[date] = np.mean(np.stack(stacks, axis=0), axis=0).astype(np.float32) + return out + + +def _read_firms_csv(path: Path) -> pd.DataFrame: + return pd.read_csv(path) + + +def _nearest_index(sorted_values: np.ndarray, values: np.ndarray) -> np.ndarray: + idx = np.searchsorted(sorted_values, values) + idx = np.clip(idx, 0, len(sorted_values) - 1) + left = np.clip(idx - 1, 0, len(sorted_values) - 1) + choose_left = np.abs(sorted_values[left] - values) <= np.abs(sorted_values[idx] - values) + return np.where(choose_left, left, idx) + + +def _firms_to_binary_grid(df: pd.DataFrame, lat: np.ndarray, lon: np.ndarray) -> np.ndarray: + label = np.zeros((lat.size, lon.size), dtype=np.float32) + if df.empty: + return label + + if "latitude" not in df.columns or "longitude" not in df.columns: 
+ raise KeyError("FIRMS CSV must include 'latitude' and 'longitude' columns.") + + lat_vals = df["latitude"].to_numpy(dtype=np.float64, copy=False) + lon_vals = df["longitude"].to_numpy(dtype=np.float64, copy=False) + valid = np.isfinite(lat_vals) & np.isfinite(lon_vals) + lat_vals = lat_vals[valid] + lon_vals = lon_vals[valid] + if lat_vals.size == 0: + return label + + lat_idx = _nearest_index(lat, lat_vals) + lon_idx = _nearest_index(lon, lon_vals) + label[lat_idx, lon_idx] = 1.0 + return label + + +def _daily_label_arrays( + label_paths: Dict[str, Path], + lat: np.ndarray, + lon: np.ndarray, + *, + allowed_dates: Set[str] | None = None, +) -> Dict[str, np.ndarray]: + out: Dict[str, np.ndarray] = {} + for date in sorted(label_paths): + if allowed_dates is not None and date not in allowed_dates: + continue + df = _read_firms_csv(label_paths[date]) + out[date] = _firms_to_binary_grid(df, lat=lat, lon=lon) + return out + + +def _write_lines(path: Path, items: Iterable[str]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text("\n".join(items), encoding="utf-8") + + +def _date_in_range(date: str, start: str, end: str) -> bool: + return start <= date <= end + + +def _write_split_files(cache_root: Path, dates: Sequence[str], split_cfg: Dict[str, Sequence[str]]) -> Dict[str, int]: + split_root = cache_root / "splits" + counts: Dict[str, int] = {} + for split_name in ("train", "val", "test"): + start, end = split_cfg[split_name] + split_dates = [d for d in dates if _date_in_range(d, str(start), str(end))] + _write_lines(split_root / f"{split_name}_dates.txt", split_dates) + counts[split_name] = len(split_dates) + return counts + + +def build_cache(config_path: str | Path, *, limit_days: int = 0) -> CacheBuildSummary: + cfg = _read_yaml(config_path) + + cache_root = Path(cfg["cache"]["root"]) + labels_dir = cache_root / "labels" + met_dir = cache_root / "met" + static_dir = cache_root / "static" + metadata_dir = cache_root / "metadata" + for path 
in (labels_dir, met_dir, static_dir, metadata_dir): + path.mkdir(parents=True, exist_ok=True) + + weather_dir = Path(cfg["data"]["weather_dir"]) + weather_glob = str(cfg["data"].get("weather_glob", "pred_2024*.nc")) + weather_vars = list(cfg["data"]["weather_vars"]) + sample_nc = weather_dir / str(cfg["data"].get("sample_nc", "pred_20240101_18.nc")) + firms_dir = Path(cfg["data"]["firms_daily_dir"]) + landfire_tif = Path(cfg["data"]["landfire_tif"]) + year = int(cfg["data"]["year"]) + + lat, lon = _load_grid(sample_nc) + np.save(metadata_dir / "lat.npy", lat.astype(np.float32)) + np.save(metadata_dir / "lon.npy", lon.astype(np.float32)) + (metadata_dir / "grid.json").write_text( + json.dumps( + { + "sample_nc": str(sample_nc), + "lat_size": int(lat.size), + "lon_size": int(lon.size), + "lat_min": float(lat.min()), + "lat_max": float(lat.max()), + "lon_min": float(lon.min()), + "lon_max": float(lon.max()), + }, + indent=2, + ), + encoding="utf-8", + ) + (metadata_dir / "vars.json").write_text(json.dumps({"weather_vars": weather_vars}, indent=2), encoding="utf-8") + + weather_groups = _discover_weather_groups(weather_dir, weather_glob) + label_paths = _discover_label_paths(firms_dir, year) + + candidate_shared_dates = sorted(set(weather_groups) & set(label_paths)) + if limit_days > 0: + candidate_shared_dates = candidate_shared_dates[:limit_days] + allowed_dates = set(candidate_shared_dates) + + met_arrays = _daily_weather_arrays(weather_groups, weather_vars, allowed_dates=allowed_dates) + label_arrays = _daily_label_arrays(label_paths, lat=lat, lon=lon, allowed_dates=allowed_dates) + + shared_dates = sorted(set(met_arrays) & set(label_arrays)) + + for date in shared_dates: + np.save(met_dir / f"{date}.npy", met_arrays[date]) + np.save(labels_dir / f"{date}.npy", label_arrays[date]) + + _write_lines(cache_root / "dates.txt", shared_dates) + + static_manifest = { + "fuel_source": str(landfire_tif), + "status": "source_registered_only", + "message": "Static fuel 
reprojection to the benchmark grid is deferred until rasterio/rioxarray are available.", + "expected_output_path": str(static_dir / "fuel.npy"), + } + (static_dir / "fuel_manifest.json").write_text(json.dumps(static_manifest, indent=2), encoding="utf-8") + + split_counts = _write_split_files(cache_root, shared_dates, cfg["splits"]) + + summary = CacheBuildSummary( + cache_root=cache_root, + n_label_days=len(label_paths), + n_met_days=len(weather_groups), + n_shared_days=len(shared_dates), + weather_vars=weather_vars, + train_days=split_counts["train"], + val_days=split_counts["val"], + test_days=split_counts["test"], + ) + (cache_root / "cache_summary.json").write_text(json.dumps(summary.to_dict(), indent=2), encoding="utf-8") + return summary + + +def align_static_fuel_to_cache( + cache_root: str | Path, + *, + landfire_tif: str | Path | None = None, + overwrite: bool = False, +) -> Dict[str, Any]: + import tifffile as tf + + cache_root = Path(cache_root) + static_dir = cache_root / "static" + metadata_dir = cache_root / "metadata" + static_dir.mkdir(parents=True, exist_ok=True) + + lat = np.load(metadata_dir / "lat.npy") + lon = np.load(metadata_dir / "lon.npy") + manifest_path = static_dir / "fuel_manifest.json" + manifest = {} + if manifest_path.exists(): + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + + source_path = Path(landfire_tif or manifest.get("fuel_source") or "") + if not str(source_path): + raise ValueError("LANDFIRE source path is required to align static fuel to the cache grid.") + if not source_path.exists(): + raise FileNotFoundError(f"LANDFIRE source not found: {source_path}") + + fuel_npy_path = static_dir / "fuel.npy" + fuel_mask_path = static_dir / "fuel_mask.npy" + aligned_tif_path = static_dir / "fuel_aligned_benchmark_grid.tif" + if fuel_npy_path.exists() and fuel_mask_path.exists() and not overwrite: + payload = json.loads(manifest_path.read_text(encoding="utf-8")) if manifest_path.exists() else {} + return payload + 
+ width = int(lon.size) + height = int(lat.size) + cmd = [ + "gdalwarp", + "-overwrite", + "-multi", + "-wo", + "NUM_THREADS=ALL_CPUS", + "-t_srs", + "EPSG:4326", + "-te", + "-180", + "-90", + "180", + "90", + "-ts", + str(width), + str(height), + "-r", + "mode", + "-srcnodata", + "32767", + "-dstnodata", + "-9999", + str(source_path), + str(aligned_tif_path), + ] + subprocess.run(cmd, check=True) + + raw = tf.imread(aligned_tif_path) + if raw.shape != (height, width): + raise ValueError(f"Aligned fuel raster shape mismatch: expected {(height, width)}, got {raw.shape}") + + valid_mask = raw >= 0 + fuel = np.where(valid_mask, raw, 0).astype(np.int16) + fuel_mask = valid_mask.astype(np.uint8) + np.save(fuel_npy_path, fuel) + np.save(fuel_mask_path, fuel_mask) + + unique_valid = np.unique(fuel[valid_mask]) if np.any(valid_mask) else np.asarray([], dtype=np.int16) + payload = { + "fuel_source": str(source_path), + "status": "aligned_to_cache_grid", + "grid_shape": [height, width], + "warp": { + "target_srs": "EPSG:4326", + "target_extent": [-180, -90, 180, 90], + "target_size": [width, height], + "resampling": "mode", + "dst_nodata": -9999, + }, + "output_files": { + "fuel": str(fuel_npy_path), + "fuel_mask": str(fuel_mask_path), + "aligned_tif": str(aligned_tif_path), + }, + "valid_cells": int(valid_mask.sum()), + "valid_fraction": float(valid_mask.mean()), + "unique_valid_values_count": int(unique_valid.size), + "unique_valid_values_sample": unique_valid[:32].astype(int).tolist(), + "notes": [ + "Static fuel values were warped from LANDFIRE CONUS Albers to the benchmark's nominal global lat-lon grid.", + "Negative values were treated as outside-domain/no-data and written as fuel=0 with fuel_mask=0.", + ], + } + manifest_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + + summary_path = cache_root / "cache_summary.json" + if summary_path.exists(): + summary = json.loads(summary_path.read_text(encoding="utf-8")) + summary["static_fuel"] = { + "status": 
"aligned", + "fuel_file": str(fuel_npy_path), + "fuel_mask_file": str(fuel_mask_path), + "valid_cells": int(valid_mask.sum()), + "valid_fraction": float(valid_mask.mean()), + } + summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") + + return payload + + +__all__ = ["CacheBuildSummary", "build_cache", "align_static_fuel_to_cache"] diff --git a/pyhazards/benchmarks/wildfire_benchmark/catalog.py b/pyhazards/benchmarks/wildfire_benchmark/catalog.py new file mode 100644 index 00000000..d2edd379 --- /dev/null +++ b/pyhazards/benchmarks/wildfire_benchmark/catalog.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Dict, List + +WILDFIRE_BENCHMARK_CONFIG_ROOT = Path(__file__).resolve().parents[2] / "configs" / "wildfire_benchmark" + + +def load_json(path: Path) -> Any: + return json.loads(path.read_text(encoding="utf-8")) + + +def load_contract(path: str | Path | None = None) -> Dict[str, Any]: + target = Path(path) if path else WILDFIRE_BENCHMARK_CONFIG_ROOT / "track_o_2024_v1.json" + return load_json(target) + + +def load_model_catalog(kind: str = "main", path: str | Path | None = None) -> List[Dict[str, Any]]: + if path is not None: + return load_json(Path(path)) + filename = "model_catalog_22.json" if kind == "main" else "model_catalog_extensions_v1.json" + return load_json(WILDFIRE_BENCHMARK_CONFIG_ROOT / filename) + + +def parse_seed_list(seed_text: str | List[int] | None) -> List[int]: + if seed_text is None: + return [42] + if isinstance(seed_text, list): + return [int(x) for x in seed_text] or [42] + seeds = [int(s.strip()) for s in str(seed_text).split(",") if s.strip()] + return seeds or [42] + + +def select_models( + all_models: List[Dict[str, Any]], + *, + source_tier: str = "all", + models: str | List[str] | None = None, + limit_models: int = 0, +) -> List[Dict[str, Any]]: + selected = list(all_models) + if source_tier != "all": + selected = [m for m in selected if 
m.get("source_tier") == source_tier] + + if models: + allowed = set(models) if isinstance(models, list) else {x.strip() for x in str(models).split(",") if x.strip()} + selected = [m for m in selected if m["name"] in allowed] + + selected = sorted(selected, key=lambda x: int(x.get("priority", 9999))) + if limit_models > 0: + selected = selected[:limit_models] + return selected diff --git a/pyhazards/benchmarks/wildfire_benchmark/experiment_settings.py b/pyhazards/benchmarks/wildfire_benchmark/experiment_settings.py new file mode 100644 index 00000000..ea76d957 --- /dev/null +++ b/pyhazards/benchmarks/wildfire_benchmark/experiment_settings.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Mapping + + +@dataclass +class BenchmarkSection: + name: str = "WildfireBench" + contract_version: str = "track_o_2024_v1" + mode: str = "scaffold_no_data" + task: str = "Track-O" + + +@dataclass +class EvaluationProtocolSection: + discrimination: dict[str, Any] = field( + default_factory=lambda: {"primary": "auprc", "secondary": "auroc"} + ) + reliability: dict[str, Any] = field( + default_factory=lambda: {"metrics": ["brier", "nll", "ece"]} + ) + temporal_consistency: dict[str, Any] = field( + default_factory=lambda: { + "metrics": ["mean_day_to_day_change", "normalized_consistency_score"] + } + ) + + +@dataclass +class ModelSection: + name: str + display_name: str + group: str + source_tier: str + train_unit: str + defaults: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class RunSection: + seed: int = 42 + num_steps: int = 0 + best_step: int = 0 + converged_step: int = 0 + step_name: str = "epoch" + learning_weight: dict[str, Any] = field( + default_factory=lambda: { + "kind": "pos_weight_neg_over_pos", + "value": "to_be_computed_from_real_train_split", + "clip_max": 50.0, + } + ) + + +@dataclass +class WildfireExperimentSetting: + benchmark: 
BenchmarkSection + evaluation_protocol: EvaluationProtocolSection + model: ModelSection + run: RunSection + metrics: dict[str, Any] = field( + default_factory=lambda: { + "auprc": None, + "auroc": None, + "brier": None, + "nll": None, + "ece": None, + "mean_day_to_day_change": None, + "normalized_consistency_score": None, + } + ) + notes: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + def write_json(self, path: str | Path) -> Path: + target = Path(path) + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(json.dumps(self.to_dict(), indent=2, sort_keys=False), encoding="utf-8") + return target + + +def build_default_experiment_setting( + *, + model_name: str, + display_name: str, + group: str, + source_tier: str, + train_unit: str, + defaults: Mapping[str, Any] | None = None, + seed: int = 42, + num_steps: int = 0, + best_step: int = 0, + converged_step: int = 0, + step_name: str = "epoch", + mode: str = "scaffold_no_data", + task: str = "Track-O", + notes: Mapping[str, Any] | None = None, + metrics: Mapping[str, Any] | None = None, +) -> WildfireExperimentSetting: + setting = WildfireExperimentSetting( + benchmark=BenchmarkSection(mode=mode, task=task), + evaluation_protocol=EvaluationProtocolSection(), + model=ModelSection( + name=model_name, + display_name=display_name, + group=group, + source_tier=source_tier, + train_unit=train_unit, + defaults=dict(defaults or {}), + ), + run=RunSection( + seed=int(seed), + num_steps=int(num_steps), + best_step=int(best_step), + converged_step=int(converged_step), + step_name=step_name, + ), + notes=dict(notes or {}), + ) + if metrics: + setting.metrics.update(dict(metrics)) + return setting + + +def write_experiment_setting( + path: str | Path, + setting: WildfireExperimentSetting, +) -> Path: + return setting.write_json(path) diff --git a/pyhazards/benchmarks/wildfire_benchmark/layout.py b/pyhazards/benchmarks/wildfire_benchmark/layout.py new 
file mode 100644 index 00000000..11e5f568 --- /dev/null +++ b/pyhazards/benchmarks/wildfire_benchmark/layout.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + +WILDFIRE_RUNS_ROOT = Path(__file__).resolve().parents[3] / "runs" / "wildfire_benchmark" + + +@dataclass(frozen=True) +class RunPaths: + track: str + run_name: str + model_name: str + seed: int + run_root: Path + model_root: Path + seed_root: Path + experiment_setting_path: Path + history_csv_path: Path + loss_curve_path: Path + metrics_path: Path + model_summary_path: Path + model_template_path: Path + benchmark_summary_path: Path + benchmark_contract_snapshot_path: Path + + +def prepare_run_paths( + track: str, + run_name: str, + model_name: str, + seed: int, + create: bool = True, +) -> RunPaths: + if track not in {"smoke", "real", "archive"}: + raise ValueError(f"Unsupported wildfire benchmark track: {track!r}") + + run_root = WILDFIRE_RUNS_ROOT / track / run_name + model_root = run_root / model_name + seed_root = model_root / f"seed_{int(seed)}" + + paths = RunPaths( + track=track, + run_name=run_name, + model_name=model_name, + seed=int(seed), + run_root=run_root, + model_root=model_root, + seed_root=seed_root, + experiment_setting_path=seed_root / "experiment_setting.json", + history_csv_path=seed_root / "history.csv", + loss_curve_path=seed_root / "loss_curve.png", + metrics_path=seed_root / "metrics.json", + model_summary_path=model_root / "model_summary.json", + model_template_path=model_root / "model_template.json", + benchmark_summary_path=run_root / "benchmark_summary.json", + benchmark_contract_snapshot_path=run_root / "benchmark_contract_snapshot.json", + ) + + if create: + seed_root.mkdir(parents=True, exist_ok=True) + return paths diff --git a/pyhazards/benchmarks/wildfire_benchmark/real_runner.py b/pyhazards/benchmarks/wildfire_benchmark/real_runner.py new file mode 100644 index 00000000..4733f89f --- /dev/null +++ 
b/pyhazards/benchmarks/wildfire_benchmark/real_runner.py @@ -0,0 +1,653 @@
# Real-data runner: trains/evaluates the representative wildfire baselines on
# the 2024 cache and writes per-seed, per-model, and benchmark-level artifacts.
from __future__ import annotations

from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, List, Sequence

import numpy as np
import torch
from sklearn.metrics import average_precision_score, brier_score_loss, log_loss, roc_auc_score

from pyhazards.datasets.wildfire import (
    WildfireTrackO2024RasterDataset,
    WildfireTrackO2024TabularDataset,
    WildfireTrackO2024TemporalDataset,
)
from pyhazards.models import build_model
from pyhazards.models.convlstm import ConvLSTMTrackOConfig, train_convlstm_track_o
from pyhazards.models.unet import UNetTrackOConfig, train_unet_track_o
from pyhazards.models.unet import binary_ece, normalized_consistency_score
from pyhazards.utils.hardware import auto_device

from .artifacts import build_model_template, mean_std, plot_loss_curve, write_history_csv, write_json
from .catalog import load_contract, load_model_catalog
from .experiment_settings import build_default_experiment_setting
from .layout import WILDFIRE_RUNS_ROOT, prepare_run_paths


# Baselines covered by this dry run; each maps to a _run_* helper below.
REPRESENTATIVE_MODELS = ("logistic_regression", "random_forest", "xgboost", "lightgbm", "unet", "convlstm")


def _to_numpy(x: torch.Tensor) -> np.ndarray:
    """Detach a tensor and move it to a host numpy array."""
    return x.detach().cpu().numpy()


def _positive_class_prob(model: torch.nn.Module, x: torch.Tensor) -> np.ndarray:
    """Return flat positive-class probabilities from a model's prediction.

    A 2-column output is treated as per-class scores (column 1 = positive);
    anything else is flattened as already-positive-class probabilities.
    """
    pred = model(x)
    arr = _to_numpy(pred)
    if arr.ndim == 2 and arr.shape[1] == 2:
        return arr[:, 1].astype(np.float32)
    return arr.reshape(-1).astype(np.float32)


def _binary_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> Dict[str, float]:
    """Compute the benchmark's discrimination/reliability/consistency metrics.

    Probabilities are clipped to (1e-7, 1 - 1e-7) so log-loss is finite.
    Any metric that raises is reported as NaN rather than aborting the run.
    """
    y_true = np.asarray(y_true, dtype=np.float32).reshape(-1)
    y_prob = np.asarray(y_prob, dtype=np.float32).reshape(-1)
    y_prob = np.clip(y_prob, 1e-7, 1.0 - 1e-7)
    # NOTE(review): "mean_day_to_day_change" is computed on np.sort(y_prob),
    # i.e. mean gap between adjacent *sorted* probabilities, not temporal
    # (day-ordered) change — confirm this matches the temporal-consistency
    # protocol's intent.
    mean_change = float(np.mean(np.abs(np.diff(np.sort(y_prob))))) if len(y_prob) > 1 else 0.0

    def _safe(callable_obj):
        # Shield each metric individually; a degenerate split (e.g. a single
        # class) yields NaN for that metric only.
        try:
            return float(callable_obj())
        except Exception:
            return float("nan")

    return {
        "auprc": _safe(lambda: average_precision_score(y_true, y_prob)),
        "auroc": _safe(lambda: roc_auc_score(y_true, y_prob)),
        "brier": _safe(lambda: brier_score_loss(y_true, y_prob)),
        "nll": _safe(lambda: log_loss(y_true, y_prob, labels=[0, 1])),
        "ece": _safe(lambda: binary_ece(y_true, y_prob, n_bins=15)),
        "mean_day_to_day_change": mean_change,
        "normalized_consistency_score": normalized_consistency_score(mean_change),
    }


def _catalog_lookup(names: Sequence[str]) -> list[dict[str, Any]]:
    """Fetch catalog rows for *names* from the main catalog; KeyError if any are missing."""
    catalog = load_model_catalog("main")
    allowed = set(names)
    selected = [row for row in catalog if row["name"] in allowed]
    if len(selected) != len(allowed):
        found = {row["name"] for row in selected}
        missing = sorted(allowed - found)
        raise KeyError(f"Missing models in wildfire main catalog: {missing}")
    return selected


def _build_setting(
    *,
    contract: dict[str, Any],
    model_spec: dict[str, Any],
    seed: int,
    num_steps: int,
    best_step: int,
    converged_step: int,
    metrics: dict[str, float],
    notes: dict[str, Any],
    learning_weight: dict[str, Any],
) -> dict[str, Any]:
    """Materialize a serializable experiment-setting dict for one (model, seed) run."""
    setting = build_default_experiment_setting(
        model_name=str(model_spec["name"]),
        display_name=str(model_spec["display_name"]),
        group=str(model_spec["group"]),
        source_tier=str(model_spec["source_tier"]),
        train_unit=str(model_spec["train_unit"]),
        defaults=model_spec.get("defaults", {}),
        seed=int(seed),
        num_steps=int(num_steps),
        best_step=int(best_step),
        converged_step=int(converged_step),
        step_name=str(model_spec["train_unit"]),
        mode=str(contract["mode"]),
        task=str(contract["task"]),
        notes=notes,
        metrics=metrics,
    )
    # The contract version and learning-weight record come from the run, not
    # the dataclass defaults.
    setting.benchmark.contract_version = str(contract["contract_version"])
    setting.run.learning_weight = learning_weight
    return setting.to_dict()


def _run_logistic_regression(
    *,
    bundle,
    seed: int,
    model_spec: dict[str, Any],
) -> dict[str, Any]:
    """Fit the logistic-regression baseline on the tabular bundle; return history + metrics."""
    defaults = dict(model_spec.get("defaults", {}))
    model = build_model("logistic_regression", task="classification", **defaults)
    model.fit_bundle(bundle, train_split="train", val_split="val")

    train_split = bundle.get_split("train")
    val_split = bundle.get_split("val")
    test_split = bundle.get_split("test")

    train_prob = np.clip(_positive_class_prob(model, train_split.inputs), 1e-7, 1.0 - 1e-7)
    val_prob = np.clip(_positive_class_prob(model, val_split.inputs), 1e-7, 1.0 - 1e-7)
    test_prob = np.clip(_positive_class_prob(model, test_split.inputs), 1e-7, 1.0 - 1e-7)

    train_true = _to_numpy(train_split.targets).reshape(-1).astype(np.float32)
    val_true = _to_numpy(val_split.targets).reshape(-1).astype(np.float32)
    test_true = _to_numpy(test_split.targets).reshape(-1).astype(np.float32)

    train_loss = float(log_loss(train_true, train_prob, labels=[0, 1]))
    val_loss = float(log_loss(val_true, val_prob, labels=[0, 1]))
    # sklearn stores per-class iteration counts in n_iter_; a single synthetic
    # "iteration" history row is recorded since there is no per-epoch curve.
    fitted_steps = int(np.max(getattr(model.estimator, "n_iter_", np.asarray([1]))))
    history = [{"iteration": fitted_steps, "train_loss": train_loss, "val_loss": val_loss}]

    return {
        "history": history,
        "val_metrics": _binary_metrics(val_true, val_prob),
        "test_metrics": _binary_metrics(test_true, test_prob),
        "best_step": fitted_steps,
        "converged_step": fitted_steps,
        "learning_weight": {"kind": "class_weight", "value": defaults.get("class_weight", "balanced")},
        "notes": {
            "model_source": "pyhazards.models.logistic_regression",
            "device": "cpu",
            "gpu_assignment": None,
        },
    }


def _run_random_forest(
    *,
    bundle,
    seed: int,
    model_spec: dict[str, Any],
) -> dict[str, Any]:
    """Fit the random-forest baseline on the tabular bundle; return history + metrics."""
    defaults = dict(model_spec.get("defaults", {}))
    model = build_model("random_forest", task="classification", **defaults)
    model.fit_bundle(bundle, train_split="train", val_split="val")

    train_split = bundle.get_split("train")
    val_split = bundle.get_split("val")
    test_split = bundle.get_split("test")

    train_prob = np.clip(_positive_class_prob(model, train_split.inputs), 1e-7, 1.0 - 1e-7)
    val_prob = np.clip(_positive_class_prob(model, val_split.inputs), 1e-7, 1.0 - 1e-7)
    test_prob = np.clip(_positive_class_prob(model, test_split.inputs), 1e-7, 1.0 - 1e-7)

    train_true = _to_numpy(train_split.targets).reshape(-1).astype(np.float32)
    val_true = _to_numpy(val_split.targets).reshape(-1).astype(np.float32)
    test_true = _to_numpy(test_split.targets).reshape(-1).astype(np.float32)

    train_loss = float(log_loss(train_true, train_prob, labels=[0, 1]))
    val_loss = float(log_loss(val_true, val_prob, labels=[0, 1]))
    # Forests have no loss curve; the tree count stands in for the step axis.
    n_estimators = int(getattr(model.estimator, 'n_estimators', defaults.get('n_estimators', 500)))
    history = [{"tree": n_estimators, "train_loss": train_loss, "val_loss": val_loss}]

    return {
        "history": history,
        "val_metrics": _binary_metrics(val_true, val_prob),
        "test_metrics": _binary_metrics(test_true, test_prob),
        "best_step": n_estimators,
        "converged_step": n_estimators,
        "learning_weight": {"kind": "class_weight", "value": defaults.get("class_weight", "balanced_subsample")},
        "notes": {
            "model_source": "pyhazards.models.random_forest",
            "device": "cpu",
            "gpu_assignment": None,
        },
    }


def _run_xgboost(
    *,
    bundle,
    seed: int,
    model_spec: dict[str, Any],
    num_boost_round: int,
) -> dict[str, Any]:
    """Train an XGBoost booster on the tabular bundle with a full eval curve."""
    import xgboost as xgb

    defaults = dict(model_spec.get("defaults", {}))
    x_train = _to_numpy(bundle.get_split("train").inputs)
    y_train = _to_numpy(bundle.get_split("train").targets).reshape(-1)
    x_val = _to_numpy(bundle.get_split("val").inputs)
    y_val = _to_numpy(bundle.get_split("val").targets).reshape(-1)
    x_test = _to_numpy(bundle.get_split("test").inputs)
    y_test = _to_numpy(bundle.get_split("test").targets).reshape(-1)

    dtrain = xgb.DMatrix(x_train, label=y_train)
    dval = xgb.DMatrix(x_val, label=y_val)
    dtest = xgb.DMatrix(x_test, label=y_test)

    params = {
        "objective": "binary:logistic",
        "eval_metric": "logloss",
        "max_depth": int(defaults.get("max_depth", 8)),
        "eta": float(defaults.get("eta", 0.05)),
        "subsample": float(defaults.get("subsample", 0.8)),
        "colsample_bytree": float(defaults.get("colsample_bytree", 0.8)),
        "seed": int(seed),
    }
    evals_result: dict[str, Any] = {}
    booster = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=int(num_boost_round),
        evals=[(dtrain, "train"), (dval, "val")],
        evals_result=evals_result,
        verbose_eval=False,
    )

    # Per-round logloss curves recorded by xgb.train via evals_result.
    train_curve = [float(v) for v in evals_result.get("train", {}).get("logloss", [])]
    val_curve = [float(v) for v in evals_result.get("val", {}).get("logloss", [])]
    history = [
        {"round": idx + 1, "train_loss": tr, "val_loss": va}
        for idx, (tr, va) in enumerate(zip(train_curve, val_curve, strict=True))
    ]
    best_step = int(np.argmin(val_curve) + 1) if val_curve else 1
    test_prob = np.clip(np.asarray(booster.predict(dtest), dtype=np.float32), 1e-7, 1.0 - 1e-7)
    val_prob = np.clip(np.asarray(booster.predict(dval), dtype=np.float32), 1e-7, 1.0 - 1e-7)

    return {
        "history": history,
        "val_metrics": _binary_metrics(y_val, val_prob),
        "test_metrics": _binary_metrics(y_test, test_prob),
        "best_step": best_step,
        "converged_step": len(history),
        "learning_weight": {"kind": "native_binary_objective", "value": "binary:logistic"},
        "notes": {
            "model_source": "pyhazards.models.xgboost",
            "device": "cpu",
            "gpu_assignment": None,
        },
    }


def _run_lightgbm(
    *,
    bundle,
    seed: int,
    model_spec: dict[str, Any],
    num_boost_round: int,
) -> dict[str, Any]:
    """Train a LightGBM booster; positive class reweighted by neg/pos (clipped at 500)."""
    import lightgbm as lgb

    defaults = dict(model_spec.get("defaults", {}))
    x_train = _to_numpy(bundle.get_split("train").inputs)
    y_train = _to_numpy(bundle.get_split("train").targets).reshape(-1)
    x_val = _to_numpy(bundle.get_split("val").inputs)
    y_val = _to_numpy(bundle.get_split("val").targets).reshape(-1)
    x_test = _to_numpy(bundle.get_split("test").inputs)
    y_test = _to_numpy(bundle.get_split("test").targets).reshape(-1)

    dtrain = lgb.Dataset(x_train, label=y_train)
    dval = lgb.Dataset(x_val, label=y_val, reference=dtrain)
    evals_result: dict[str, Any] = {}
    # max(..., 1.0) guards against empty-class divisions on tiny dry-run splits.
    train_pos = max(float(y_train.sum()), 1.0)
    train_neg = max(float(y_train.size - y_train.sum()), 1.0)
    scale_pos_weight = float(defaults.get("scale_pos_weight", min(train_neg / train_pos, 500.0)))
    params = {
        "objective": "binary",
        "metric": "binary_logloss",
        "num_leaves": int(defaults.get("num_leaves", 15)),
        "learning_rate": float(defaults.get("learning_rate", 0.03)),
        "feature_fraction": float(defaults.get("feature_fraction", 0.8)),
        "bagging_fraction": float(defaults.get("bagging_fraction", 0.8)),
        "bagging_freq": int(defaults.get("bagging_freq", 1)),
        "min_data_in_leaf": int(defaults.get("min_data_in_leaf", 200)),
        "min_sum_hessian_in_leaf": float(defaults.get("min_sum_hessian_in_leaf", 1e-3)),
        "lambda_l2": float(defaults.get("lambda_l2", 1.0)),
        "max_depth": int(defaults.get("max_depth", -1)),
        "scale_pos_weight": scale_pos_weight,
        "seed": int(seed),
        "verbose": -1,
        "force_col_wise": True,
    }
    booster = lgb.train(
        params=params,
        train_set=dtrain,
        num_boost_round=int(num_boost_round),
        valid_sets=[dtrain, dval],
        valid_names=["train", "val"],
        callbacks=[
            # period=0 silences logging; record_evaluation captures the curves.
            lgb.log_evaluation(period=0),
            lgb.record_evaluation(evals_result),
        ],
    )

    train_curve = [float(v) for v in evals_result.get("train", {}).get("binary_logloss", [])]
    val_curve = [float(v) for v in evals_result.get("val", {}).get("binary_logloss", [])]
    history = [
        {"round": idx + 1, "train_loss": tr, "val_loss": va}
        for idx, (tr, va) in enumerate(zip(train_curve, val_curve, strict=True))
    ]
    best_step = int(np.argmin(val_curve) + 1) if val_curve else len(history)
    val_prob = np.clip(np.asarray(booster.predict(x_val), dtype=np.float32), 1e-7, 1.0 - 1e-7)
    test_prob = np.clip(np.asarray(booster.predict(x_test), dtype=np.float32), 1e-7, 1.0 - 1e-7)

    return {
        "history": history,
        "val_metrics": _binary_metrics(y_val, val_prob),
        "test_metrics": _binary_metrics(y_test, test_prob),
        "best_step": best_step,
        "converged_step": len(history),
        "learning_weight": {"kind": "scale_pos_weight", "value": float(scale_pos_weight), "derived_from": "train_neg_over_pos_clipped"},
        "notes": {
            "model_source": "pyhazards.models.lightgbm",
            "device": "cpu",
            "gpu_assignment": None,
        },
    }


def _run_unet(
    *,
    bundle,
    seed: int,
    device: str,
    max_epochs: int,
    patience: int,
) -> dict[str, Any]:
    """Train the U-Net baseline on the raster bundle and score the test split."""
    train_split = bundle.get_split("train")
    val_split = bundle.get_split("val")
    test_split = bundle.get_split("test")

    cfg = UNetTrackOConfig(
        # Fall back to the tensor's channel axis when the spec leaves it unset.
        in_channels=int(bundle.feature_spec.channels or train_split.inputs.shape[1]),
        batch_size=4,
        max_epochs=int(max_epochs),
        early_stopping_rounds=int(patience),
        seed=int(seed),
        device=device,
    )
    model, history, val_metrics, best_epoch, pos_weight = train_unet_track_o(
        _to_numpy(train_split.inputs),
        _to_numpy(train_split.targets),
        _to_numpy(val_split.inputs),
        _to_numpy(val_split.targets),
        cfg,
    )

    # Evaluate on CUDA only when it was requested AND is actually available.
    with torch.no_grad():
        logits = model(test_split.inputs.to(torch.device(device if str(device).startswith("cuda") and torch.cuda.is_available() else "cpu")))
        test_prob = torch.sigmoid(logits).detach().cpu().numpy().reshape(-1)
    test_true = _to_numpy(test_split.targets).reshape(-1)

    return {
        "history": history,
        "val_metrics": val_metrics,
        "test_metrics": _binary_metrics(test_true, test_prob),
        "best_step": int(best_epoch),
        "converged_step": len(history),
        "learning_weight": {"kind": "pos_weight_neg_over_pos", "value": float(pos_weight), "clip_max": float(cfg.pos_weight_clip_max)},
        "notes": {
            "model_source": "pyhazards.models.unet",
            "device": device,
            "gpu_assignment": device if str(device).startswith("cuda") else None,
        },
    }


def _run_convlstm(
    *,
    bundle,
    seed: int,
    device: str,
    max_epochs: int,
    patience: int,
    history_len: int,
) -> dict[str, Any]:
    """Train the ConvLSTM baseline on the temporal bundle and score the test split."""
    train_split = bundle.get_split("train")
    val_split = bundle.get_split("val")
    test_split = bundle.get_split("test")

    cfg = ConvLSTMTrackOConfig(
        seq_len=int(history_len),
        # Temporal inputs carry the channel axis at position 2 (after time).
        in_channels=int(bundle.feature_spec.channels or train_split.inputs.shape[2]),
        batch_size=2,
        max_epochs=int(max_epochs),
        early_stopping_rounds=int(patience),
        seed=int(seed),
        device=device,
    )
    model, history, val_metrics, best_epoch, pos_weight = train_convlstm_track_o(
        _to_numpy(train_split.inputs),
        _to_numpy(train_split.targets),
        _to_numpy(val_split.inputs),
        _to_numpy(val_split.targets),
        cfg,
    )

    eval_device = torch.device(device if str(device).startswith("cuda") and torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        logits = model(test_split.inputs.to(eval_device))
        test_prob = torch.sigmoid(logits).detach().cpu().numpy().reshape(-1)
    test_true = _to_numpy(test_split.targets).reshape(-1)

    return {
        "history": history,
        "val_metrics": val_metrics,
        "test_metrics": _binary_metrics(test_true, test_prob),
        "best_step": int(best_epoch),
        "converged_step": len(history),
        "learning_weight": {"kind": "pos_weight_neg_over_pos", "value": float(pos_weight), "clip_max": float(cfg.pos_weight_clip_max)},
        "notes": {
            "model_source": "pyhazards.models.convlstm",
            "device": device,
            "gpu_assignment": device if str(device).startswith("cuda") else None,
        },
    }


def _write_per_seed_outputs(
    *,
    contract: dict[str, Any],
    model_spec: dict[str, Any],
    seed: int,
    run_name: str,
    result: dict[str, Any],
    notes_extra: dict[str, Any],
) -> dict[str, Any]:
    """Persist one (model, seed) run: history CSV, loss curve, metrics, setting JSON."""
    paths = prepare_run_paths(track="real", run_name=run_name, model_name=str(model_spec["name"]), seed=int(seed), create=True)
    write_history_csv(paths.history_csv_path, result["history"])
    plot_loss_curve(result["history"], str(model_spec["train_unit"]), paths.loss_curve_path, f"{model_spec['display_name']} ({model_spec['train_unit']})")
    metrics_payload = {
        "val": result["val_metrics"],
        "test": result["test_metrics"],
        "best_step": int(result["best_step"]),
        "converged_step": int(result["converged_step"]),
    }
    write_json(paths.metrics_path, metrics_payload)

    setting = _build_setting(
        contract=contract,
        model_spec=model_spec,
        seed=int(seed),
        num_steps=int(result["converged_step"]),
        best_step=int(result["best_step"]),
        converged_step=int(result["converged_step"]),
        metrics=result["test_metrics"],
        notes={**result["notes"], **notes_extra, "val_metrics": result["val_metrics"]},
        learning_weight=result["learning_weight"],
    )
    write_json(paths.experiment_setting_path, setting)
    return {"paths": paths, "metrics": metrics_payload, "setting": setting}


def run_real_baselines(
    *,
    # NOTE(review): machine-specific absolute default path; consider making the
    # cache location configurable via env/config rather than a hard-coded home dir.
    cache_dir: str | Path = "/home/runyang/my-copy/data_cache/wildfire_2024_v1",
    run_name: str = "track_o_2024_real_v1_first4_dryrun",
    models: Sequence[str] | None = None,
    seed: int = 42,
    train_limit_days: int | None = None,
    val_limit_days: int | None = None,
    test_limit_days: int | None = None,
    tabular_downsample: int = 8,
    raster_downsample: int = 4,
    temporal_downsample: int = 8,
    temporal_history: int = 6,
    xgboost_rounds: int = 120,
    lightgbm_rounds: int = 120,
    unet_epochs: int = 12,
    convlstm_epochs: int = 12,
    deep_patience: int = 4,
    device: str | None = None,
) -> Path:
    """Run the representative baselines on the real 2024 cache for a single seed.

    Loads only the dataset bundles required by the selected models, dispatches
    each model to its _run_* helper, and writes per-seed artifacts, per-model
    summaries, and a run-level benchmark summary + experiment templates.
    Returns the run root directory.
    """
    cache_root = Path(cache_dir)
    contract = load_contract(Path(__file__).resolve().parents[2] / "configs" / "wildfire_benchmark" / "track_o_2024_real_v1.json")
    selected_names = tuple(models or REPRESENTATIVE_MODELS)
    model_specs = _catalog_lookup(selected_names)

    run_root = WILDFIRE_RUNS_ROOT / "real" / run_name
    run_root.mkdir(parents=True, exist_ok=True)
    write_json(run_root / "benchmark_contract_snapshot.json", contract)

    dataset_common = {
        "cache_dir": str(cache_root),
        "train_limit_days": train_limit_days,
        "val_limit_days": val_limit_days,
        "test_limit_days": test_limit_days,
    }
    # Lazily materialize only the bundles the selected models actually need.
    bundles: dict[str, Any] = {}
    if any(name in selected_names for name in ("logistic_regression", "random_forest", "xgboost", "lightgbm")):
        bundles["tabular"] = WildfireTrackO2024TabularDataset(
            downsample_factor=tabular_downsample,
            **dataset_common,
        ).load()
    if "unet" in selected_names:
        bundles["raster"] = WildfireTrackO2024RasterDataset(
            downsample_factor=raster_downsample,
            **dataset_common,
        ).load()
    if "convlstm" in selected_names:
        bundles["temporal"] = WildfireTrackO2024TemporalDataset(
            history=temporal_history,
            downsample_factor=temporal_downsample,
            **dataset_common,
        ).load()

    device_text = str(device or auto_device())
    benchmark_rows: List[dict[str, Any]] = []
    templates_index: dict[str, Any] = {}

    for model_spec in model_specs:
        name = str(model_spec["name"])
        model_root = run_root / name
        model_root.mkdir(parents=True, exist_ok=True)
        template = build_model_template(contract, model_spec)
        templates_index[name] = template
        write_json(model_root / "model_template.json", template)

        # Dispatch to the model-specific trainer with its matching bundle.
        if name == "logistic_regression":
            result = _run_logistic_regression(bundle=bundles["tabular"], seed=seed, model_spec=model_spec)
            dataset_meta = bundles["tabular"].metadata
        elif name == "random_forest":
            result = _run_random_forest(bundle=bundles["tabular"], seed=seed, model_spec=model_spec)
            dataset_meta = bundles["tabular"].metadata
        elif name == "xgboost":
            result = _run_xgboost(bundle=bundles["tabular"], seed=seed, model_spec=model_spec, num_boost_round=xgboost_rounds)
            dataset_meta = bundles["tabular"].metadata
        elif name == "lightgbm":
            result = _run_lightgbm(bundle=bundles["tabular"], seed=seed, model_spec=model_spec, num_boost_round=lightgbm_rounds)
            dataset_meta = bundles["tabular"].metadata
        elif name == "unet":
            result = _run_unet(bundle=bundles["raster"], seed=seed, device=device_text, max_epochs=unet_epochs, patience=deep_patience)
            dataset_meta = bundles["raster"].metadata
        elif name == "convlstm":
            result = _run_convlstm(
                bundle=bundles["temporal"],
                seed=seed,
                device=device_text,
                max_epochs=convlstm_epochs,
                patience=deep_patience,
                history_len=temporal_history,
            )
            dataset_meta = bundles["temporal"].metadata
        else:
            raise ValueError(f"Unsupported representative model: {name}")

        has_static_fuel = bool(dataset_meta.get("has_static_fuel", False))
        # NOTE(review): the returned payload is currently unused; kept for
        # parity with _write_per_seed_outputs' contract.
        payload = _write_per_seed_outputs(
            contract=contract,
            model_spec=model_spec,
            seed=seed,
            run_name=run_name,
            result=result,
            notes_extra={
                "cache_root": str(cache_root),
                "dataset_metadata": dataset_meta,
                "split_version": "cache_2024_v1",
                "feature_set_version": "weather_plus_fuel_v1" if has_static_fuel else "weather_only_v1_static_fuel_pending",
                "static_fuel_status": "aligned" if has_static_fuel else "manifest_only",
            },
        )

        # Single-seed run: mean/std are computed over a one-element list so
        # the summary schema matches multi-seed runs.
        metric_stats = {k: mean_std([float(v)]) for k, v in result["test_metrics"].items()}
        write_json(
            model_root / "model_summary.json",
            {
                "model": {
                    "name": name,
                    "display_name": model_spec["display_name"],
                    "group": model_spec["group"],
                    "source_tier": model_spec["source_tier"],
                    "train_unit": model_spec["train_unit"],
                },
                "mode": contract["mode"],
                "n_seeds": 1,
                "seeds": [int(seed)],
                "metrics_mean_std": metric_stats,
                "per_seed": [
                    {
                        "seed": int(seed),
                        "best_step": int(result["best_step"]),
                        "converged_step": int(result["converged_step"]),
                        "train_unit": model_spec["train_unit"],
                        **result["test_metrics"],
                    }
                ],
            },
        )
        benchmark_rows.append(
            {
                "name": name,
                "display_name": model_spec["display_name"],
                "group": model_spec["group"],
                "source_tier": model_spec["source_tier"],
                "train_unit": model_spec["train_unit"],
                "auprc_mean": metric_stats.get("auprc", {}).get("mean"),
                "auprc_std": metric_stats.get("auprc", {}).get("std"),
                "auroc_mean": metric_stats.get("auroc", {}).get("mean"),
                "auroc_std": metric_stats.get("auroc", {}).get("std"),
                "brier_mean": metric_stats.get("brier", {}).get("mean"),
                "nll_mean": metric_stats.get("nll", {}).get("mean"),
                "ece_mean": metric_stats.get("ece", {}).get("mean"),
                "normalized_consistency_score_mean": metric_stats.get("normalized_consistency_score", {}).get("mean"),
            }
        )

    write_json(
        run_root / "benchmark_summary.json",
        {
            "benchmark": {
                "name": contract["benchmark_name"],
                "contract_version": contract["contract_version"],
                "mode": contract["mode"],
                "task": contract["task"],
                "generated_at": datetime.now().isoformat(),
                "note": "First real-data dry run on the 2024 wildfire cache.",
                "cache_root": str(cache_root),
            },
            "models_selected": list(selected_names),
            "n_models": len(selected_names),
            "seeds": [int(seed)],
            "rows": benchmark_rows,
        },
    )
    write_json(
        run_root / "experiment_templates.json",
        {
            "template_version": "track_o_model_template_v1",
            "generated_at": datetime.now().isoformat(),
            "models": templates_index,
        },
    )
    return run_root


__all__ = ["REPRESENTATIVE_MODELS", "run_real_baselines"]
diff --git a/pyhazards/benchmarks/wildfire_benchmark/runner.py b/pyhazards/benchmarks/wildfire_benchmark/runner.py new file mode 100644 index 00000000..aaa16869 --- /dev/null +++ b/pyhazards/benchmarks/wildfire_benchmark/runner.py @@ -0,0 +1,158 @@
from __future__ import annotations

from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List

from .artifacts import (
    build_experiment_setting_from_run_output,
    build_model_template,
    mean_std,
    plot_loss_curve,
    write_history_csv,
    write_json,
)
from .catalog import load_contract, load_model_catalog, parse_seed_list, select_models
from .layout import WILDFIRE_RUNS_ROOT, prepare_run_paths


def run_smoke_batch(
    *,
    adapter_factory: Callable[[Dict[str, Any], Dict[str, Any], Dict[str, int]], Any],
    run_name: str | None = None,
    track: str = "smoke",
    catalog_kind: str = "main",
    catalog_path: str | Path | None = None,
    contract_path: str | Path | None = None,
    source_tier: str
= "all", + models: str | List[str] | None = None, + seeds: str | List[int] | None = None, + limit_models: int = 0, + step_limits: Dict[str, int] | None = None, +) -> Path: + contract = load_contract(contract_path) + catalog = load_model_catalog(catalog_kind, catalog_path) + selected_models = select_models(catalog, source_tier=source_tier, models=models, limit_models=limit_models) + if not selected_models: + raise ValueError("No wildfire benchmark models selected.") + + seed_list = parse_seed_list(seeds) + step_limits = step_limits or {"epoch": 60, "round": 300, "iteration": 250, "tree": 300} + run_name = run_name or f"smoke_batch_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + run_root = WILDFIRE_RUNS_ROOT / track / run_name + run_root.mkdir(parents=True, exist_ok=True) + write_json(run_root / "benchmark_contract_snapshot.json", contract) + + benchmark_rows: List[Dict[str, Any]] = [] + templates_index: Dict[str, Dict[str, Any]] = {} + + for model_spec in selected_models: + model_name = str(model_spec["name"]) + model_root = run_root / model_name + model_root.mkdir(parents=True, exist_ok=True) + + template = build_model_template(contract, model_spec) + templates_index[model_name] = template + write_json(model_root / "model_template.json", template) + + metric_pool: Dict[str, List[float]] = {} + per_seed_rows: List[Dict[str, Any]] = [] + + for seed in seed_list: + paths = prepare_run_paths(track=track, run_name=run_name, model_name=model_name, seed=int(seed), create=True) + adapter = adapter_factory(model_spec, contract, step_limits) + run_output = adapter.run(seed=int(seed)) + + write_history_csv(paths.history_csv_path, run_output.history) + plot_loss_curve(run_output.history, run_output.train_unit, paths.loss_curve_path, f"{model_spec['display_name']} ({run_output.train_unit})") + + if hasattr(adapter, 'build_experiment_setting'): + experiment_setting = adapter.build_experiment_setting(seed=int(seed), run_output=run_output) + else: + experiment_setting = 
build_experiment_setting_from_run_output( + contract=contract, + model_spec=model_spec, + seed=int(seed), + run_output=run_output, + ) + write_json(paths.experiment_setting_path, experiment_setting) + write_json(paths.metrics_path, run_output.metrics) + + for key, value in run_output.metrics.items(): + metric_pool.setdefault(key, []).append(float(value)) + per_seed_rows.append( + { + 'seed': int(seed), + 'best_step': int(run_output.best_step), + 'converged_step': int(run_output.converged_step), + 'train_unit': run_output.train_unit, + **run_output.metrics, + } + ) + + metric_stats = {k: mean_std(v) for k, v in metric_pool.items()} + write_json( + model_root / 'model_summary.json', + { + 'model': { + 'name': model_name, + 'display_name': model_spec['display_name'], + 'group': model_spec['group'], + 'source_tier': model_spec['source_tier'], + 'train_unit': model_spec['train_unit'], + }, + 'mode': contract['mode'], + 'n_seeds': len(seed_list), + 'seeds': seed_list, + 'metrics_mean_std': metric_stats, + 'per_seed': per_seed_rows, + }, + ) + benchmark_rows.append( + { + 'name': model_name, + 'display_name': model_spec['display_name'], + 'group': model_spec['group'], + 'source_tier': model_spec['source_tier'], + 'train_unit': model_spec['train_unit'], + 'auprc_mean': metric_stats.get('auprc', {}).get('mean'), + 'auprc_std': metric_stats.get('auprc', {}).get('std'), + 'auroc_mean': metric_stats.get('auroc', {}).get('mean'), + 'auroc_std': metric_stats.get('auroc', {}).get('std'), + 'brier_mean': metric_stats.get('brier', {}).get('mean'), + 'nll_mean': metric_stats.get('nll', {}).get('mean'), + 'ece_mean': metric_stats.get('ece', {}).get('mean'), + 'normalized_consistency_score_mean': metric_stats.get('normalized_consistency_score', {}).get('mean'), + } + ) + + write_json( + run_root / 'benchmark_summary.json', + { + 'benchmark': { + 'name': contract['benchmark_name'], + 'contract_version': contract['contract_version'], + 'mode': contract['mode'], + 'task': 
contract['task'], + 'generated_at': datetime.now().isoformat(), + 'note': 'Adapter-level smoke run.', + 'contract_path': str(contract_path) if contract_path else 'pyhazards/configs/wildfire_benchmark/track_o_2024_v1.json', + 'catalog_kind': catalog_kind, + }, + 'models_selected': [m['name'] for m in selected_models], + 'n_models': len(selected_models), + 'seeds': seed_list, + 'rows': benchmark_rows, + }, + ) + + templates_payload = { + 'template_version': 'track_o_model_template_v1', + 'generated_at': datetime.now().isoformat(), + 'models': templates_index, + } + write_json(run_root / 'experiment_templates.json', templates_payload) + if catalog_kind == 'main': + write_json(run_root / 'experiment_templates_22.json', templates_payload) + return run_root diff --git a/pyhazards/configs/wildfire_benchmark/cache_2024_v1.yaml b/pyhazards/configs/wildfire_benchmark/cache_2024_v1.yaml new file mode 100644 index 00000000..9eb76ea6 --- /dev/null +++ b/pyhazards/configs/wildfire_benchmark/cache_2024_v1.yaml @@ -0,0 +1,30 @@ +cache: + root: /home/runyang/my-copy/data_cache/wildfire_2024_v1 + +data: + year: 2024 + weather_dir: /home/runyang/output2024 + weather_glob: pred_2024*.nc + sample_nc: pred_20240101_18.nc + weather_vars: + - T2M + - QV2M + - TQV + - U10M + - V10M + - GWETROOT + - TS + - LAI + - EFLUX + - HFLUX + - SWGNT + - SWTNT + - LWGAB + - LWGEM + firms_daily_dir: /home/runyang/ryang/firms/combine + landfire_tif: /home/runyang/ryang/landfire_fbfm40/LF2024_FBFM13_250_CONUS/Tif/LC24_F13_250.tif + +splits: + train: [2024-01-01, 2024-09-30] + val: [2024-10-01, 2024-10-31] + test: [2024-11-01, 2024-12-31] diff --git a/pyhazards/configs/wildfire_benchmark/model_catalog_22.json b/pyhazards/configs/wildfire_benchmark/model_catalog_22.json new file mode 100644 index 00000000..a1b6122d --- /dev/null +++ b/pyhazards/configs/wildfire_benchmark/model_catalog_22.json @@ -0,0 +1,295 @@ +[ + { + "name": "logistic_regression", + "display_name": "Logistic Regression", + "group": 
"classical_trees", + "train_unit": "iteration", + "source_tier": "no_official_repo", + "priority": 100, + "defaults": { + "solver": "lbfgs", + "max_iter": 500, + "class_weight": "balanced" + } + }, + { + "name": "random_forest", + "display_name": "Random Forest", + "group": "classical_trees", + "train_unit": "tree", + "source_tier": "no_official_repo", + "priority": 101, + "defaults": { + "n_estimators": 500, + "max_depth": null, + "class_weight": "balanced_subsample" + } + }, + { + "name": "xgboost", + "display_name": "XGBoost", + "group": "classical_trees", + "train_unit": "round", + "source_tier": "official_repo", + "priority": 1, + "defaults": { + "max_depth": 8, + "eta": 0.05, + "subsample": 0.8, + "colsample_bytree": 0.8, + "num_boost_round": 800 + } + }, + { + "name": "lightgbm", + "display_name": "LightGBM", + "group": "classical_trees", + "train_unit": "round", + "source_tier": "official_repo", + "priority": 2, + "defaults": { + "num_leaves": 63, + "learning_rate": 0.05, + "feature_fraction": 0.8, + "bagging_fraction": 0.8, + "num_boost_round": 800 + } + }, + { + "name": "unet", + "display_name": "U-Net", + "group": "segmentation_cnns", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 3, + "defaults": { + "optimizer": "AdamW", + "lr": 0.001, + "max_epochs": 120 + } + }, + { + "name": "resnet18_unet", + "display_name": "ResNet-18 U-Net", + "group": "segmentation_cnns", + "train_unit": "epoch", + "source_tier": "no_official_repo", + "priority": 102, + "defaults": { + "backbone": "resnet18", + "optimizer": "AdamW", + "lr": 0.001, + "max_epochs": 120 + } + }, + { + "name": "attention_unet", + "display_name": "Attention U-Net", + "group": "segmentation_cnns", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 4, + "defaults": { + "optimizer": "AdamW", + "lr": 0.001, + "max_epochs": 120 + } + }, + { + "name": "deeplabv3p", + "display_name": "DeepLabv3+", + "group": "segmentation_cnns", + "train_unit": "epoch", + 
"source_tier": "paper_only", + "priority": 80, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0003, + "max_epochs": 120 + } + }, + { + "name": "convlstm", + "display_name": "ConvLSTM", + "group": "spatiotemporal", + "train_unit": "epoch", + "source_tier": "no_official_repo", + "priority": 103, + "defaults": { + "optimizer": "Adam", + "lr": 0.001, + "max_epochs": 120 + } + }, + { + "name": "mau", + "display_name": "MAU", + "group": "spatiotemporal", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 5, + "defaults": { + "optimizer": "Adam", + "lr": 0.0005, + "max_epochs": 120 + } + }, + { + "name": "predrnn_v2", + "display_name": "PredRNN-v2", + "group": "spatiotemporal", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 6, + "defaults": { + "optimizer": "Adam", + "lr": 0.0005, + "max_epochs": 120 + } + }, + { + "name": "rainformer", + "display_name": "Rainformer", + "group": "spatiotemporal", + "train_unit": "epoch", + "source_tier": "no_official_repo", + "priority": 104, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0002, + "max_epochs": 120 + } + }, + { + "name": "earthformer", + "display_name": "Earthformer", + "group": "spatiotemporal", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 7, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0002, + "max_epochs": 120 + } + }, + { + "name": "swinlstm", + "display_name": "SwinLSTM", + "group": "spatiotemporal", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 8, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0002, + "max_epochs": 120 + } + }, + { + "name": "earthfarseer", + "display_name": "EarthFarseer", + "group": "spatiotemporal", + "train_unit": "epoch", + "source_tier": "paper_only", + "priority": 81, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0002, + "max_epochs": 120 + } + }, + { + "name": "convgru_trajgru", + "display_name": "ConvGRU / TrajGRU", + "group": "spatiotemporal", + "train_unit": "epoch", + 
"source_tier": "no_official_repo", + "priority": 105, + "defaults": { + "optimizer": "Adam", + "lr": 0.001, + "max_epochs": 120 + } + }, + { + "name": "tcn", + "display_name": "TCN", + "group": "spatiotemporal", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 9, + "defaults": { + "optimizer": "Adam", + "lr": 0.001, + "max_epochs": 120 + } + }, + { + "name": "utae", + "display_name": "UTAE", + "group": "spatiotemporal", + "train_unit": "epoch", + "source_tier": "paper_only", + "priority": 82, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0003, + "max_epochs": 120 + } + }, + { + "name": "segformer", + "display_name": "SegFormer", + "group": "transformers", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 10, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0002, + "max_epochs": 120 + } + }, + { + "name": "swin_unet", + "display_name": "Swin-Unet", + "group": "transformers", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 11, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0002, + "max_epochs": 120 + } + }, + { + "name": "vit_segmenter", + "display_name": "ViT-based Segmenter", + "group": "transformers", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 12, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0002, + "max_epochs": 120 + } + }, + { + "name": "deep_ensemble", + "display_name": "Deep Ensemble", + "group": "uncertainty", + "train_unit": "epoch", + "source_tier": "no_official_repo", + "priority": 106, + "defaults": { + "base_model": "tcn", + "ensemble_size": 5, + "optimizer": "AdamW", + "lr": 0.001, + "max_epochs": 120 + } + } +] diff --git a/pyhazards/configs/wildfire_benchmark/model_catalog_extensions_v1.json b/pyhazards/configs/wildfire_benchmark/model_catalog_extensions_v1.json new file mode 100644 index 00000000..2b51f979 --- /dev/null +++ b/pyhazards/configs/wildfire_benchmark/model_catalog_extensions_v1.json @@ -0,0 +1,284 @@ +[ + { + "name": 
"cnn_aspp", + "display_name": "CNN-ASPP", + "group": "satellite_remote_sensing", + "train_unit": "epoch", + "source_tier": "paper_only", + "priority": 201, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0003, + "max_epochs": 120 + } + }, + { + "name": "asufm", + "display_name": "ASUFM", + "group": "satellite_remote_sensing", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 202, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0002, + "max_epochs": 120 + } + }, + { + "name": "firecastnet", + "display_name": "FireCastNet", + "group": "seasonal_forecasting", + "train_unit": "epoch", + "source_tier": "paper_only", + "priority": 203, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0002, + "max_epochs": 120 + } + }, + { + "name": "firepred", + "display_name": "FirePred", + "group": "satellite_remote_sensing", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 203, + "defaults": { + "history": 5, + "optimizer": "AdamW", + "lr": 0.0002, + "max_epochs": 120 + } + }, + { + "name": "viirs_375m_active_fire", + "display_name": "VIIRS 375 m Active Fire", + "group": "operational_detection", + "train_unit": "epoch", + "source_tier": "official_paper", + "priority": 204, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0002, + "max_epochs": 80 + } + }, + { + "name": "modis_active_fire_c61", + "display_name": "MODIS Active Fire C6.1", + "group": "operational_detection", + "train_unit": "epoch", + "source_tier": "official_paper", + "priority": 205, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0002, + "max_epochs": 80 + } + }, + { + "name": "wrf_sfire_adapter", + "display_name": "WRF-SFIRE Adapter", + "group": "physics_simulators", + "train_unit": "iteration", + "source_tier": "official_repo", + "priority": 204, + "defaults": { + "max_iter": 120 + } + }, + { + "name": "forefire_adapter", + "display_name": "ForeFire Adapter", + "group": "physics_simulators", + "train_unit": "iteration", + "source_tier": "official_repo", + 
"priority": 205, + "defaults": { + "max_iter": 120 + } + }, + { + "name": "wildfiregpt", + "display_name": "WildfireGPT", + "group": "llm_systems", + "train_unit": "iteration", + "source_tier": "official_repo", + "priority": 206, + "defaults": { + "max_iter": 80 + } + }, + { + "name": "gemini_25_pro_wildfire_prompted", + "display_name": "Gemini 2.5 Pro Wildfire Prompted", + "group": "llm_systems", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 206, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0001, + "max_epochs": 60, + "in_channels": 6, + "hidden_dim": 96, + "prompt_dim": 32, + "num_prompt_tokens": 6, + "num_heads": 8, + "dropout": 0.1 + } + }, + { + "name": "llama4_wildfire_prompted", + "display_name": "Llama 4 Wildfire Prompted", + "group": "llm_systems", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 207, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0001, + "max_epochs": 60, + "in_channels": 6, + "hidden_dim": 80, + "prompt_dim": 32, + "num_prompt_tokens": 4, + "num_heads": 8, + "dropout": 0.1 + } + }, + { + "name": "qwen25_vl_wildfire_prompted", + "display_name": "Qwen2.5-VL Wildfire Prompted", + "group": "llm_systems", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 206, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0001, + "max_epochs": 60 + } + }, + { + "name": "internvl3_wildfire_prompted", + "display_name": "InternVL3 Wildfire Prompted", + "group": "llm_systems", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 208, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0001, + "max_epochs": 60, + "in_channels": 6, + "hidden_dim": 96, + "prompt_dim": 32, + "num_prompt_tokens": 5, + "num_heads": 6, + "dropout": 0.1 + } + }, + { + "name": "firemm_ir", + "display_name": "FireMM-IR", + "group": "llm_systems", + "train_unit": "epoch", + "source_tier": "paper_only", + "priority": 207, + "defaults": { + "optimizer": "AdamW", + "lr": 0.0001, + 
"max_epochs": 60 + } + }, + { + "name": "prithvi_eo_2_tl", + "display_name": "Prithvi-EO-2.0-TL", + "group": "foundation_models", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 208, + "defaults": { + "backbone": "Prithvi-EO-2.0", + "optimizer": "AdamW", + "lr": 0.0001, + "max_epochs": 80 + } + }, + { + "name": "prithvi_burnscars", + "display_name": "Prithvi BurnScars", + "group": "foundation_models", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 209, + "defaults": { + "backbone": "Prithvi-BurnScars", + "optimizer": "AdamW", + "lr": 0.0001, + "max_epochs": 80 + } + }, + { + "name": "prithvi_wxc", + "display_name": "Prithvi-WxC", + "group": "foundation_models", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 210, + "defaults": { + "backbone": "Prithvi-WxC", + "optimizer": "AdamW", + "lr": 0.0001, + "max_epochs": 80 + } + }, + { + "name": "wildfirespreadts", + "display_name": "WildfireSpreadTS", + "group": "satellite_remote_sensing", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 211, + "defaults": { + "history": 4, + "optimizer": "AdamW", + "lr": 0.0002, + "max_epochs": 120 + } + }, + { + "name": "ts_satfire", + "display_name": "TS-SatFire", + "group": "satellite_remote_sensing", + "train_unit": "epoch", + "source_tier": "official_repo", + "priority": 213, + "defaults": { + "history": 5, + "optimizer": "AdamW", + "lr": 0.0002, + "max_epochs": 120 + } + }, + { + "name": "wildfire_fpa", + "display_name": "DNN-LSTM-AutoEncoder", + "group": "forecasting_systems", + "train_unit": "epoch", + "source_tier": "paper_only", + "priority": 214, + "defaults": { + "depth": 2, + "hidden_dim": 64, + "activation": "relu", + "dropout": 0.1, + "optimizer": "AdamW", + "lr": 0.0002, + "max_epochs": 120 + } + } +] diff --git a/pyhazards/configs/wildfire_benchmark/track_o_2024_real_v1.json b/pyhazards/configs/wildfire_benchmark/track_o_2024_real_v1.json new file mode 100644 index 
00000000..71b0b86d --- /dev/null +++ b/pyhazards/configs/wildfire_benchmark/track_o_2024_real_v1.json @@ -0,0 +1,86 @@ +{ + "benchmark_name": "WildfireBench", + "contract_version": "track_o_2024_real_v1", + "mode": "real_data_v1", + "task": "Track-O", + "description": "Unified real-data wildfire occurrence benchmark protocol for 2024 using FIRMS labels, Prithvi-WxC weather predictions, and LANDFIRE static fuels.", + "data": { + "year": 2024, + "label_source": "/home/runyang/ryang/firms/combine", + "dynamic_feature_sources": [ + "/home/runyang/output2024" + ], + "static_feature_sources": [ + "/home/runyang/ryang/landfire_fbfm40" + ], + "optional_feature_sources": [ + "/home/runyang/ryang/WFIGS_Perimeters/history_2024", + "/home/runyang/ryang/WRC_Housing_Density", + "/home/runyang/ryang/LandScan_Global_2024" + ], + "index_unit": "county_day", + "split": { + "train": ["2024-01-01", "2024-09-30"], + "val": ["2024-10-01", "2024-10-31"], + "test": ["2024-11-01", "2024-12-31"] + }, + "leakage_control": { + "fit_statistics_on_train_only": true, + "no_future_covariates": true, + "fixed_split_files_required_for_real_runs": true + } + }, + "shared_training": { + "dry_run_seed_list": [42], + "final_seed_list": [42, 52, 62, 72, 82], + "optimizer_default": "AdamW", + "learning_rate_default": 0.001, + "weight_decay_default": 0.0001, + "class_imbalance": { + "policy": "pos_weight_neg_over_pos", + "clip_max": 50.0 + }, + "early_stopping": { + "enabled": true, + "monitor": "val_auprc", + "patience": 20, + "min_delta": 0.0001 + }, + "convergence_rule": { + "monitor": "val_loss", + "smoothing_window": 5, + "patience": 20, + "min_improvement": 0.0001 + }, + "report_requirements": { + "report_mean_std_across_seeds": true, + "must_include_train_curve": true, + "must_include_val_curve": true, + "must_log_best_step": true, + "must_log_converged_step": true, + "must_log_device": true, + "must_log_gpu_assignment": true + } + }, + "metrics": { + "primary": ["auprc"], + "secondary": ["auroc"], 
+ "reliability": ["brier", "nll", "ece"], + "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"] + }, + "output_schema": { + "root": "/home/runyang/my-copy/runs/wildfire_benchmark/real/track_o_2024_real_v1", + "per_seed_files": [ + "experiment_setting.json", + "history.csv", + "loss_curve.png", + "metrics.json" + ], + "per_model_files": [ + "model_summary.json" + ], + "benchmark_files": [ + "benchmark_summary.json" + ] + } +} diff --git a/pyhazards/configs/wildfire_benchmark/track_o_2024_v1.json b/pyhazards/configs/wildfire_benchmark/track_o_2024_v1.json new file mode 100644 index 00000000..f20f9bf5 --- /dev/null +++ b/pyhazards/configs/wildfire_benchmark/track_o_2024_v1.json @@ -0,0 +1,82 @@ +{ + "benchmark_name": "WildfireBench", + "contract_version": "track_o_2024_v1", + "mode": "scaffold_no_data", + "task": "Track-O", + "description": "Unified county-day wildfire occurrence benchmark protocol for 2024 only.", + "data": { + "year": 2024, + "label_source": "/home/runyang/ryang/firms_download/combine", + "dynamic_feature_sources": [ + "/home/runyang/output2024" + ], + "static_feature_sources": [ + "/home/runyang/ryang/landfire_fbfm40" + ], + "external_optional_sources": [ + "geo", + "vegetation", + "flood" + ], + "index_unit": "county_day", + "split": { + "train": ["2024-01-01", "2024-09-30"], + "val": ["2024-10-01", "2024-10-31"], + "test": ["2024-11-01", "2024-12-31"] + }, + "leakage_control": { + "fit_statistics_on_train_only": true, + "no_future_covariates": true, + "fixed_split_files_required_for_real_runs": true + } + }, + "shared_training": { + "seed_list": [42, 52, 62, 72, 82], + "optimizer_default": "AdamW", + "learning_rate_default": 0.001, + "weight_decay_default": 0.0001, + "class_imbalance": { + "policy": "pos_weight_neg_over_pos", + "clip_max": 50.0 + }, + "early_stopping": { + "enabled": true, + "monitor": "val_auprc", + "patience": 10, + "min_delta": 0.0005 + }, + "convergence_rule": { + "monitor": "val_loss", + 
"smoothing_window": 5, + "patience": 5, + "min_improvement": 0.001 + }, + "report_requirements": { + "report_mean_std_across_seeds": true, + "must_include_train_curve": true, + "must_include_val_curve": true, + "must_log_best_step": true, + "must_log_converged_step": true + } + }, + "metrics": { + "primary": ["auprc"], + "secondary": ["auroc"], + "reliability": ["brier", "nll", "ece"], + "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"] + }, + "output_schema": { + "per_seed_files": [ + "experiment_setting.json", + "history.csv", + "loss_curve.png", + "metrics.json" + ], + "per_model_files": [ + "model_summary.json" + ], + "benchmark_files": [ + "benchmark_summary.json" + ] + } +} diff --git a/pyhazards/datasets/__init__.py b/pyhazards/datasets/__init__.py index 1850e566..dc6a4086 100644 --- a/pyhazards/datasets/__init__.py +++ b/pyhazards/datasets/__init__.py @@ -23,7 +23,14 @@ TCBenchAlphaDataset, TropiCycloneNetDataset, ) -from .wildfire import SyntheticWildfireSpreadDataset, SyntheticWildfireSpreadTemporalDataset +from .wildfire import ( + SyntheticWildfireSpreadDataset, + SyntheticWildfireSpreadTemporalDataset, + TrackOSplitConfig, + WildfireTrackO2024RasterDataset, + WildfireTrackO2024TabularDataset, + WildfireTrackO2024TemporalDataset, +) __all__ = [ "DataBundle", @@ -55,6 +62,10 @@ "TropiCycloneNetDataset", "SyntheticWildfireSpreadDataset", "SyntheticWildfireSpreadTemporalDataset", + "TrackOSplitConfig", + "WildfireTrackO2024RasterDataset", + "WildfireTrackO2024TabularDataset", + "WildfireTrackO2024TemporalDataset", ] register_dataset(SyntheticEarthquakeForecastDataset.name, SyntheticEarthquakeForecastDataset) @@ -76,3 +87,6 @@ register_dataset(TropiCycloneNetDataset.name, TropiCycloneNetDataset) register_dataset(SyntheticWildfireSpreadDataset.name, SyntheticWildfireSpreadDataset) register_dataset(SyntheticWildfireSpreadTemporalDataset.name, SyntheticWildfireSpreadTemporalDataset) 
def _read_lines(path: Path) -> list[str]:
    """Return the non-empty, whitespace-stripped lines of *path*.

    Raises:
        FileNotFoundError: If the split file does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(f"Expected split file not found: {path}")
    stripped = (raw.strip() for raw in path.read_text(encoding="utf-8").splitlines())
    return [line for line in stripped if line]


def _load_weather_vars(cache_root: Path) -> list[str]:
    """Read the ordered weather-variable names recorded in the cache metadata."""
    vars_file = cache_root / "metadata" / "vars.json"
    return list(json.loads(vars_file.read_text(encoding="utf-8"))["weather_vars"])


def _load_lat_lon(cache_root: Path) -> tuple[np.ndarray, np.ndarray]:
    """Load the cached latitude and longitude grids as float32 arrays."""
    meta_dir = cache_root / "metadata"
    lat_grid = np.asarray(np.load(meta_dir / "lat.npy"), dtype=np.float32)
    lon_grid = np.asarray(np.load(meta_dir / "lon.npy"), dtype=np.float32)
    return lat_grid, lon_grid


def _subset_dates(dates: Sequence[str], limit: int | None) -> list[str]:
    """Return at most *limit* leading dates; None or a non-positive limit keeps all."""
    if limit is not None and int(limit) > 0:
        return list(dates[: int(limit)])
    return list(dates)


def _crop_hw_to_multiple(arr: np.ndarray, multiple: int = 4) -> np.ndarray:
    """Trim the trailing H/W axes so both are multiples of *multiple*.

    Accepts (C, H, W) or (H, W) arrays; any other rank raises ValueError.
    """
    if arr.ndim not in (2, 3):
        raise ValueError(f"Unsupported array rank for cropping: {arr.ndim}")
    height, width = arr.shape[-2], arr.shape[-1]
    keep_h = max(height - (height % multiple), multiple)
    keep_w = max(width - (width % multiple), multiple)
    return arr[..., :keep_h, :keep_w]


def _spatial_downsample(arr: np.ndarray, factor: int) -> np.ndarray:
    """Stride-sample H/W by *factor*, cast to float32, and crop to a multiple of 4."""
    if factor <= 1:
        sampled = arr
    elif arr.ndim == 3:
        sampled = arr[:, ::factor, ::factor]
    elif arr.ndim == 2:
        sampled = arr[::factor, ::factor]
    else:
        raise ValueError(f"Unsupported array rank for downsampling: {arr.ndim}")
    sampled = np.asarray(sampled, dtype=np.float32)
    return np.asarray(_crop_hw_to_multiple(sampled, multiple=4), multiple and np.float32)


def _compute_channel_stats(cache_root: Path, dates: Sequence[str], downsample_factor: int) -> tuple[np.ndarray, np.ndarray]:
    """Accumulate per-channel mean/std over the given dates' weather arrays.

    Uses float64 accumulators for numerical stability and clamps the variance
    at 1e-12 so the returned std is never zero. Callers pass the training
    split's dates only, keeping normalization statistics leakage-free.
    """
    met_dir = cache_root / "met"
    first = _spatial_downsample(np.load(met_dir / f"{dates[0]}.npy"), downsample_factor)
    n_channels = int(first.shape[0])
    total = np.zeros((n_channels,), dtype=np.float64)
    total_sq = np.zeros((n_channels,), dtype=np.float64)
    n_values = 0
    for date in dates:
        day = _spatial_downsample(np.load(met_dir / f"{date}.npy"), downsample_factor)
        flat = day.reshape(n_channels, -1).astype(np.float64, copy=False)
        total += flat.sum(axis=1)
        total_sq += np.square(flat).sum(axis=1)
        n_values += flat.shape[1]
    mean = total / max(n_values, 1)
    variance = np.maximum(total_sq / max(n_values, 1) - np.square(mean), 1e-12)
    return mean.astype(np.float32), np.sqrt(variance).astype(np.float32)
arr.reshape(channels, -1).astype(np.float64, copy=False) + sums += flat.sum(axis=1) + sums_sq += np.square(flat).sum(axis=1) + count += flat.shape[1] + mean = sums / max(count, 1) + var = np.maximum(sums_sq / max(count, 1) - np.square(mean), 1e-12) + std = np.sqrt(var) + return mean.astype(np.float32), std.astype(np.float32) + + +def _normalize_weather(arr: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray: + return ((arr - mean[:, None, None]) / std[:, None, None]).astype(np.float32) + + +def _load_static_fuel(cache_root: Path, downsample_factor: int) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]: + fuel_path = cache_root / "static" / "fuel.npy" + fuel_mask_path = cache_root / "static" / "fuel_mask.npy" + if not fuel_path.exists(): + return None, None + fuel = np.load(fuel_path) + fuel_mask = np.load(fuel_mask_path) if fuel_mask_path.exists() else (fuel > 0).astype(np.uint8) + fuel = _spatial_downsample(fuel.astype(np.float32), downsample_factor) + fuel_mask = _spatial_downsample(fuel_mask.astype(np.float32), downsample_factor) + fuel_mask = (fuel_mask > 0.5).astype(np.float32) + fuel = np.where(fuel_mask > 0, fuel / 100.0, 0.0).astype(np.float32) + return fuel, fuel_mask + + +def _date_to_cyclical_features(date_text: str) -> tuple[float, float]: + month = int(date_text[5:7]) + day = int(date_text[8:10]) + day_of_year = (month - 1) * 31 + day + angle = 2.0 * math.pi * (float(day_of_year) / 366.0) + return float(math.sin(angle)), float(math.cos(angle)) + + +@dataclass(frozen=True) +class TrackOSplitConfig: + train_limit_days: int | None = None + val_limit_days: int | None = None + test_limit_days: int | None = None + + +class _WildfireTrackOBase(Dataset): + name = "wildfire_track_o_2024_base" + + def __init__( + self, + cache_dir: str | None = None, + *, + downsample_factor: int = 1, + train_limit_days: int | None = None, + val_limit_days: int | None = None, + test_limit_days: int | None = None, + ): + super().__init__(cache_dir=cache_dir) + 
self.cache_root = Path(cache_dir or "/home/runyang/my-copy/data_cache/wildfire_2024_v1") + self.downsample_factor = max(1, int(downsample_factor)) + self.split_cfg = TrackOSplitConfig( + train_limit_days=train_limit_days, + val_limit_days=val_limit_days, + test_limit_days=test_limit_days, + ) + + def _load_split_dates(self) -> dict[str, list[str]]: + split_root = self.cache_root / "splits" + return { + "train": _subset_dates(_read_lines(split_root / "train_dates.txt"), self.split_cfg.train_limit_days), + "val": _subset_dates(_read_lines(split_root / "val_dates.txt"), self.split_cfg.val_limit_days), + "test": _subset_dates(_read_lines(split_root / "test_dates.txt"), self.split_cfg.test_limit_days), + } + + +class WildfireTrackO2024RasterDataset(_WildfireTrackOBase): + name = "wildfire_track_o_2024_raster" + + def __init__( + self, + cache_dir: str | None = None, + *, + downsample_factor: int = 4, + train_limit_days: int | None = None, + val_limit_days: int | None = None, + test_limit_days: int | None = None, + ): + super().__init__( + cache_dir=cache_dir, + downsample_factor=downsample_factor, + train_limit_days=train_limit_days, + val_limit_days=val_limit_days, + test_limit_days=test_limit_days, + ) + + def _load(self) -> DataBundle: + split_dates = self._load_split_dates() + weather_vars = _load_weather_vars(self.cache_root) + mean, std = _compute_channel_stats(self.cache_root, split_dates["train"], self.downsample_factor) + fuel, fuel_mask = _load_static_fuel(self.cache_root, self.downsample_factor) + + splits: dict[str, DataSplit] = {} + for split_name, dates in split_dates.items(): + x_rows: list[np.ndarray] = [] + y_rows: list[np.ndarray] = [] + for date in dates: + x = _spatial_downsample(np.load(self.cache_root / "met" / f"{date}.npy"), self.downsample_factor) + x = _normalize_weather(x, mean, std) + if fuel is not None and fuel_mask is not None: + x = np.concatenate([x, fuel[None, :, :], fuel_mask[None, :, :]], axis=0) + x_rows.append(x.astype(np.float32)) 
class WildfireTrackO2024TemporalDataset(_WildfireTrackOBase):
    """Sliding-window temporal wildfire dataset.

    Each sample is a (history, C, H, W) sequence of normalized weather frames
    (plus optional static fuel channels repeated per frame); the target is the
    binary occurrence grid for the final date in the window.
    """

    name = "wildfire_track_o_2024_temporal"

    def __init__(
        self,
        cache_dir: str | None = None,
        *,
        history: int = 6,
        downsample_factor: int = 8,
        train_limit_days: int | None = None,
        val_limit_days: int | None = None,
        test_limit_days: int | None = None,
    ):
        super().__init__(
            cache_dir=cache_dir,
            downsample_factor=downsample_factor,
            train_limit_days=train_limit_days,
            val_limit_days=val_limit_days,
            test_limit_days=test_limit_days,
        )
        # Number of consecutive daily frames per input sequence.
        self.history = int(history)

    def _load(self) -> DataBundle:
        """Build windowed sequences per split and wrap them in a DataBundle.

        Raises:
            ValueError: If a split has fewer dates than ``history``.
        """
        split_dates = self._load_split_dates()
        weather_vars = _load_weather_vars(self.cache_root)
        # Normalization stats come from the train split only (leakage control).
        mean, std = _compute_channel_stats(self.cache_root, split_dates["train"], self.downsample_factor)
        fuel, fuel_mask = _load_static_fuel(self.cache_root, self.downsample_factor)
        static_channels = None
        if fuel is not None and fuel_mask is not None:
            static_channels = np.stack([fuel, fuel_mask], axis=0).astype(np.float32)

        splits: dict[str, DataSplit] = {}
        for split_name, dates in split_dates.items():
            x_rows: list[np.ndarray] = []
            y_rows: list[np.ndarray] = []
            used_dates: list[str] = []

            # First usable target is at index history-1 so every window is full.
            for idx in range(self.history - 1, len(dates)):
                seq_dates = dates[idx - self.history + 1 : idx + 1]
                seq_arrays = []
                for date in seq_dates:
                    x = _spatial_downsample(np.load(self.cache_root / "met" / f"{date}.npy"), self.downsample_factor)
                    x = _normalize_weather(x, mean, std)
                    if static_channels is not None:
                        # Static fuel channels are repeated on every frame.
                        x = np.concatenate([x, static_channels], axis=0)
                    seq_arrays.append(x.astype(np.float32))
                x_rows.append(np.stack(seq_arrays, axis=0).astype(np.float32))
                target = _spatial_downsample(np.load(self.cache_root / "labels" / f"{dates[idx]}.npy"), self.downsample_factor)
                y_rows.append(target[None, :, :].astype(np.float32))
                used_dates.append(dates[idx])

            if not x_rows:
                raise ValueError(
                    f"Temporal split '{split_name}' has no usable samples. Need at least history={self.history} dates, got {len(dates)}."
                )

            # Shapes: inputs (N, history, C, H, W), targets (N, 1, H, W).
            x_np = np.stack(x_rows, axis=0).astype(np.float32)
            y_np = np.stack(y_rows, axis=0).astype(np.float32)
            splits[split_name] = DataSplit(
                inputs=torch.from_numpy(x_np),
                targets=torch.from_numpy(y_np),
                metadata={"dates": used_dates, "history": self.history},
            )

        sample_shape = splits["train"].inputs.shape
        return DataBundle(
            splits=splits,
            feature_spec=FeatureSpec(
                channels=int(sample_shape[2]),
                description="Temporal weather histories for wildfire occurrence prediction.",
                extra={
                    "history": self.history,
                    "height": int(sample_shape[3]),
                    "width": int(sample_shape[4]),
                    "downsample_factor": self.downsample_factor,
                    "weather_vars": weather_vars,
                    "static_feature_names": ["fuel_class_scaled", "fuel_valid_mask"] if static_channels is not None else [],
                },
            ),
            label_spec=LabelSpec(
                num_targets=1,
                task_type="segmentation",
                description="Binary daily wildfire occurrence grid for the last frame in each history window.",
            ),
            metadata={
                "dataset": self.name,
                "cache_root": str(self.cache_root),
                "has_static_fuel": fuel is not None,
                "normalization": {
                    "mean": mean.tolist(),
                    "std": std.tolist(),
                    "fit_split": "train",
                },
                "splits": {k: int(v.inputs.shape[0]) for k, v in splits.items()},
            },
        )
self._load_split_dates() + weather_vars = _load_weather_vars(self.cache_root) + lat, lon = _load_lat_lon(self.cache_root) + mean, std = _compute_channel_stats(self.cache_root, split_dates["train"], self.downsample_factor) + fuel, fuel_mask = _load_static_fuel(self.cache_root, self.downsample_factor) + + sample_met = _spatial_downsample(np.load(self.cache_root / "met" / f"{split_dates['train'][0]}.npy"), self.downsample_factor) + lat = lat[:: self.downsample_factor][: sample_met.shape[1]] + lon = lon[:: self.downsample_factor][: sample_met.shape[2]] + lat_grid, lon_grid = np.meshgrid(lat, lon, indexing="ij") + coord_block = np.stack([lat_grid, lon_grid], axis=-1).reshape(-1, 2).astype(np.float32) + coord_mean = coord_block.mean(axis=0, keepdims=True) + coord_std = coord_block.std(axis=0, keepdims=True) + 1e-6 + coord_block = (coord_block - coord_mean) / coord_std + + splits: dict[str, DataSplit] = {} + for split_name, dates in split_dates.items(): + x_rows: list[np.ndarray] = [] + y_rows: list[np.ndarray] = [] + row_dates: list[str] = [] + + for date in dates: + met = _spatial_downsample(np.load(self.cache_root / "met" / f"{date}.npy"), self.downsample_factor) + met = _normalize_weather(met, mean, std) + features = met.reshape(met.shape[0], -1).T.astype(np.float32) + + extras: list[np.ndarray] = [] + if self.include_coords: + extras.append(coord_block) + if self.include_day_of_year: + sin_doy, cos_doy = _date_to_cyclical_features(date) + extras.append( + np.repeat( + np.asarray([[sin_doy, cos_doy]], dtype=np.float32), + repeats=features.shape[0], + axis=0, + ) + ) + if extras: + features = np.concatenate([features, *extras], axis=1) + if fuel is not None and fuel_mask is not None: + fuel_cols = fuel.reshape(-1, 1).astype(np.float32) + fuel_mask_cols = fuel_mask.reshape(-1, 1).astype(np.float32) + features = np.concatenate([features, fuel_cols, fuel_mask_cols], axis=1) + + label = _spatial_downsample(np.load(self.cache_root / "labels" / f"{date}.npy"), 
self.downsample_factor) + labels = label.reshape(-1).astype(np.float32) + + x_rows.append(features) + y_rows.append(labels) + row_dates.extend([date] * features.shape[0]) + + x_np = np.concatenate(x_rows, axis=0).astype(np.float32) + y_np = np.concatenate(y_rows, axis=0).astype(np.float32) + splits[split_name] = DataSplit( + inputs=torch.from_numpy(x_np), + targets=torch.from_numpy(y_np), + metadata={"row_dates": row_dates}, + ) + + feature_names = list(weather_vars) + if self.include_coords: + feature_names.extend(["lat", "lon"]) + if self.include_day_of_year: + feature_names.extend(["sin_doy", "cos_doy"]) + if fuel is not None and fuel_mask is not None: + feature_names.extend(["fuel_class_scaled", "fuel_valid_mask"]) + + return DataBundle( + splits=splits, + feature_spec=FeatureSpec( + input_dim=int(splits["train"].inputs.shape[1]), + description="Tabularized wildfire occurrence features from daily gridded cache values.", + extra={ + "downsample_factor": self.downsample_factor, + "feature_names": feature_names, + }, + ), + label_spec=LabelSpec( + num_targets=1, + task_type="classification", + description="Binary wildfire occurrence for each grid cell and day in tabular form.", + ), + metadata={ + "dataset": self.name, + "cache_root": str(self.cache_root), + "has_static_fuel": fuel is not None, + "normalization": { + "weather_mean": mean.tolist(), + "weather_std": std.tolist(), + "fit_split": "train", + }, + "splits": {k: int(v.inputs.shape[0]) for k, v in splits.items()}, + }, + ) + + +__all__ = [ + "TrackOSplitConfig", + "WildfireTrackO2024RasterDataset", + "WildfireTrackO2024TemporalDataset", + "WildfireTrackO2024TabularDataset", +] diff --git a/pyhazards/models/__init__.py b/pyhazards/models/__init__.py index ea923edc..55d0f5cb 100644 --- a/pyhazards/models/__init__.py +++ b/pyhazards/models/__init__.py @@ -5,6 +5,30 @@ from .eqnet import EQNet, eqnet_builder from .eqtransformer import EQTransformer, eqtransformer_builder from .firecastnet import FireCastNet, 
firecastnet_builder +from .firemm_ir import FireMMIR, firemm_ir_builder +from .firepred import FirePred, firepred_builder +from .gemini_25_pro_wildfire_prompted import ( + Gemini25ProWildfirePrompted, + gemini_25_pro_wildfire_prompted_builder, +) +from .internvl3_wildfire_prompted import ( + InternVL3WildfirePrompted, + internvl3_wildfire_prompted_builder, +) +from .llama4_wildfire_prompted import ( + Llama4WildfirePrompted, + llama4_wildfire_prompted_builder, +) +from .modis_active_fire_c61 import MODISActiveFireC61, modis_active_fire_c61_builder +from .prithvi_burnscars import PrithviBurnScars, prithvi_burnscars_builder +from .prithvi_eo_2_tl import PrithviEO2TL, prithvi_eo_2_tl_builder +from .prithvi_wxc import PrithviWxC, prithvi_wxc_builder +from .qwen25_vl_wildfire_prompted import ( + Qwen25VLWildfirePrompted, + qwen25_vl_wildfire_prompted_builder, +) +from .ts_satfire import TSSatFire, ts_satfire_builder +from .viirs_375m_active_fire import VIIRS375mActiveFire, viirs_375m_active_fire_builder from .floodcast import FloodCast, floodcast_builder from .forefire import ForeFireAdapter, forefire_builder from .fourcastnet_tc import FourCastNetTC, fourcastnet_tc_builder @@ -31,10 +55,32 @@ WavefieldMetrics, wavecastnet_builder, ) -from .wildfire_forecasting import WildfireForecasting, wildfire_forecasting_builder from .wildfire_aspp import TverskyLoss, WildfireASPP, wildfire_aspp_builder from .wildfire_fpa import WildfireFPA, wildfire_fpa_builder from .wildfire_mamba import WildfireMamba, wildfire_mamba_builder +from .wildfiregpt import WildfireGPTReasoner, wildfiregpt_builder +from .logistic_regression import LogisticRegressionModel, logistic_regression_builder +from .random_forest import RandomForestModel, random_forest_builder +from .xgboost import XGBoostModel, xgboost_builder +from .lightgbm import LightGBMModel, lightgbm_builder +from .unet import TinyUNet, unet_builder +from .resnet18_unet import TinyResNet18UNet, resnet18_unet_builder +from .attention_unet 
import TinyAttentionUNet, attention_unet_builder +from .deeplabv3p import TinyDeepLabV3P, deeplabv3p_builder +from .convlstm import TinyConvLSTM, convlstm_builder +from .mau import TinyMAU, mau_builder +from .predrnn_v2 import TinyPredRNNv2, predrnn_v2_builder +from .rainformer import TinyRainformer, rainformer_builder +from .earthformer import TinyEarthFormer, earthformer_builder +from .swinlstm import TinySwinLSTM, swinlstm_builder +from .earthfarseer import TinyEarthFarseer, earthfarseer_builder +from .convgru_trajgru import TinyConvGRTrajGRU, convgru_trajgru_builder +from .tcn import TinyTCN, tcn_builder +from .utae import TinyUTAE, utae_builder +from .segformer import TinySegFormer, segformer_builder +from .swin_unet import TinySwinUNet, swin_unet_builder +from .vit_segmenter import TinyViTSegmenter, vit_segmenter_builder +from .deep_ensemble import DeepEnsemble, deep_ensemble_builder from .wildfirespreadts import WildfireSpreadTS, wildfirespreadts_builder from .wrf_sfire import WRFSFireAdapter, wrf_sfire_builder @@ -57,6 +103,30 @@ "eqtransformer_builder", "FireCastNet", "firecastnet_builder", + "FireMMIR", + "firemm_ir_builder", + "FirePred", + "firepred_builder", + "Gemini25ProWildfirePrompted", + "gemini_25_pro_wildfire_prompted_builder", + "InternVL3WildfirePrompted", + "internvl3_wildfire_prompted_builder", + "Llama4WildfirePrompted", + "llama4_wildfire_prompted_builder", + "MODISActiveFireC61", + "modis_active_fire_c61_builder", + "PrithviBurnScars", + "prithvi_burnscars_builder", + "PrithviEO2TL", + "prithvi_eo_2_tl_builder", + "PrithviWxC", + "prithvi_wxc_builder", + "Qwen25VLWildfirePrompted", + "qwen25_vl_wildfire_prompted_builder", + "TSSatFire", + "ts_satfire_builder", + "VIIRS375mActiveFire", + "viirs_375m_active_fire_builder", "FloodCast", "floodcast_builder", "ForeFireAdapter", @@ -97,12 +167,12 @@ "wildfire_aspp_builder", "WildfireCNNASPP", "cnn_aspp_builder", - "WildfireForecasting", - "wildfire_forecasting_builder", "WildfireFPA", 
"wildfire_fpa_builder", "WildfireMamba", "wildfire_mamba_builder", + "WildfireGPTReasoner", + "wildfiregpt_builder", "WildfireSpreadTS", "wildfirespreadts_builder", "WRFSFireAdapter", @@ -137,15 +207,17 @@ "wildfire_fpa", wildfire_fpa_builder, defaults={ - "out_dim": 5, - "output_dim": 5, + "in_dim": 8, + "out_dim": 1, + "input_dim": 7, + "output_dim": 1, "depth": 2, "hidden_dim": 64, "activation": "relu", - "dropout": None, + "dropout": 0.1, "latent_dim": 32, - "num_layers": 1, - "lookback": 50, + "num_layers": 2, + "lookback": 12, }, ) @@ -169,28 +241,27 @@ defaults={"in_channels": 12}, ) -register_model( - "wildfire_forecasting", - wildfire_forecasting_builder, - defaults={ - "input_dim": 7, - "hidden_dim": 64, - "output_dim": 5, - "lookback": 12, - "num_layers": 2, - "dropout": 0.1, - }, -) - register_model( "asufm", asufm_builder, defaults={ - "input_dim": 7, - "hidden_dim": 64, - "output_dim": 5, - "lookback": 12, - "dropout": 0.1, + "image_size": 64, + "patch_size": 4, + "in_channels": 6, + "out_dim": 1, + "embed_dim": 96, + "depths": (2, 2, 2, 2), + "num_heads": (3, 6, 12, 24), + "window_size": 8, + "mlp_ratio": 4.0, + "dropout": 0.0, + "drop_path_rate": 0.1, + "focal_window": 3, + "focal_level": 2, + "use_focal_modulation": True, + "spatial_attention": True, + "skip_num": 3, + "use_checkpoint": True, }, ) @@ -237,6 +308,196 @@ }, ) +register_model( + "firepred", + firepred_builder, + defaults={ + "history": 5, + "in_channels": 8, + "hidden_dim": 32, + "out_channels": 1, + "dropout": 0.1, + }, +) + +register_model( + "modis_active_fire_c61", + modis_active_fire_c61_builder, + defaults={ + "in_channels": 5, + "hidden_dim": 24, + "out_dim": 1, + "context_kernel": 9, + "dropout": 0.1, + }, +) + +register_model( + "prithvi_eo_2_tl", + prithvi_eo_2_tl_builder, + defaults={ + "image_size": 32, + "in_channels": 6, + "out_dim": 1, + "patch_size": 4, + "embed_dim": 128, + "depth": 4, + "num_heads": 4, + "mlp_ratio": 4.0, + "dropout": 0.1, + "time_dim": 1, + 
"location_dim": 2, + "decoder_channels": 64, + }, +) + +register_model( + "prithvi_burnscars", + prithvi_burnscars_builder, + defaults={ + "image_size": 32, + "in_channels": 6, + "out_dim": 1, + "patch_size": 4, + "embed_dim": 128, + "depth": 4, + "num_heads": 4, + "mlp_ratio": 4.0, + "dropout": 0.1, + "time_dim": 1, + "location_dim": 2, + "decoder_channels": 64, + }, +) + +register_model( + "prithvi_wxc", + prithvi_wxc_builder, + defaults={ + "image_size": 32, + "in_channels": 8, + "out_dim": 1, + "patch_size": 4, + "embed_dim": 128, + "depth": 4, + "num_heads": 4, + "mlp_ratio": 4.0, + "dropout": 0.1, + "lead_time_dim": 1, + "variable_summary_dim": 8, + "decoder_channels": 64, + }, +) + +register_model( + "gemini_25_pro_wildfire_prompted", + gemini_25_pro_wildfire_prompted_builder, + defaults={ + "in_channels": 6, + "out_dim": 1, + "hidden_dim": 96, + "prompt_dim": 32, + "num_prompt_tokens": 6, + "num_heads": 8, + "dropout": 0.1, + }, +) + +register_model( + "internvl3_wildfire_prompted", + internvl3_wildfire_prompted_builder, + defaults={ + "in_channels": 6, + "out_dim": 1, + "hidden_dim": 96, + "prompt_dim": 32, + "num_prompt_tokens": 5, + "num_heads": 6, + "dropout": 0.1, + }, +) + +register_model( + "llama4_wildfire_prompted", + llama4_wildfire_prompted_builder, + defaults={ + "in_channels": 6, + "out_dim": 1, + "hidden_dim": 80, + "prompt_dim": 32, + "num_prompt_tokens": 4, + "num_heads": 8, + "dropout": 0.1, + }, +) + +register_model( + "qwen25_vl_wildfire_prompted", + qwen25_vl_wildfire_prompted_builder, + defaults={ + "in_channels": 6, + "out_dim": 1, + "hidden_dim": 64, + "prompt_dim": 24, + "num_prompt_tokens": 4, + "num_heads": 4, + "dropout": 0.1, + }, +) + +register_model( + "ts_satfire", + ts_satfire_builder, + defaults={ + "history": 5, + "in_channels": 8, + "hidden_dim": 32, + "out_channels": 1, + "dropout": 0.1, + }, +) + +register_model( + "viirs_375m_active_fire", + viirs_375m_active_fire_builder, + defaults={ + "in_channels": 5, + 
"hidden_dim": 24, + "out_dim": 1, + "context_kernel": 7, + "dropout": 0.1, + }, +) + +register_model( + "wildfiregpt", + wildfiregpt_builder, + defaults={ + "in_channels": 12, + "out_dim": 1, + "base_channels": 32, + "hidden_dim": 64, + "profile_dim": 8, + "retrieved_dim": 16, + "num_heads": 4, + "dropout": 0.1, + }, +) + +register_model( + "firemm_ir", + firemm_ir_builder, + defaults={ + "in_channels": 6, + "out_dim": 1, + "hidden_dim": 64, + "instruction_dim": 16, + "num_memory_slots": 3, + "num_heads": 4, + "dropout": 0.1, + }, +) + register_model( "wildfire_cnn_aspp", cnn_aspp_builder, @@ -473,3 +734,287 @@ "dropout": 0.1, }, ) + + +__all__.extend([ + "LogisticRegressionModel", "logistic_regression_builder", + "RandomForestModel", "random_forest_builder", + "XGBoostModel", "xgboost_builder", + "LightGBMModel", "lightgbm_builder", + "TinyUNet", "unet_builder", + "TinyResNet18UNet", "resnet18_unet_builder", + "TinyAttentionUNet", "attention_unet_builder", + "TinyDeepLabV3P", "deeplabv3p_builder", + "TinyConvLSTM", "convlstm_builder", + "TinyMAU", "mau_builder", + "TinyPredRNNv2", "predrnn_v2_builder", + "TinyRainformer", "rainformer_builder", + "TinyEarthFormer", "earthformer_builder", + "TinySwinLSTM", "swinlstm_builder", + "TinyEarthFarseer", "earthfarseer_builder", + "TinyConvGRTrajGRU", "convgru_trajgru_builder", + "TinyTCN", "tcn_builder", + "TinyUTAE", "utae_builder", + "TinySegFormer", "segformer_builder", + "TinySwinUNet", "swin_unet_builder", + "TinyViTSegmenter", "vit_segmenter_builder", + "DeepEnsemble", "deep_ensemble_builder", +]) + + +register_model( + "logistic_regression", + logistic_regression_builder, + defaults={ + "solver": "lbfgs", + "max_iter": 500, + "class_weight": "balanced", + }, +) + +register_model( + "random_forest", + random_forest_builder, + defaults={ + "n_estimators": 500, + "max_depth": None, + "class_weight": "balanced_subsample", + }, +) + +register_model( + "xgboost", + xgboost_builder, + defaults={ + "max_depth": 8, + "eta": 
0.05, + "subsample": 0.8, + "colsample_bytree": 0.8, + "num_boost_round": 800, + }, +) + +register_model( + "lightgbm", + lightgbm_builder, + defaults={ + "num_leaves": 63, + "learning_rate": 0.05, + "feature_fraction": 0.8, + "bagging_fraction": 0.8, + "num_boost_round": 800, + }, +) + +register_model( + "unet", + unet_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "base_channels": 16, + }, +) + +register_model( + "resnet18_unet", + resnet18_unet_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "stem_channels": 16, + }, +) + +register_model( + "attention_unet", + attention_unet_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "base_channels": 8, + }, +) + +register_model( + "deeplabv3p", + deeplabv3p_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "base_channels": 16, + }, +) + +register_model( + "convlstm", + convlstm_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "enc_channels": 16, + "hidden_channels": 16, + "num_layers": 2, + "kernel_size": 3, + }, +) + +register_model( + "mau", + mau_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "hidden_channels": 12, + }, +) + +register_model( + "predrnn_v2", + predrnn_v2_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "hidden_channels": 12, + }, +) + +register_model( + "rainformer", + rainformer_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "hidden_channels": 16, + "num_heads": 4, + "num_layers": 2, + }, +) + +register_model( + "earthformer", + earthformer_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "hidden_channels": 16, + "num_heads": 4, + "num_layers": 2, + }, +) + +register_model( + "swinlstm", + swinlstm_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "embed_dim": 16, + "hidden_channels": 16, + "num_heads": 4, + "window_size": 3, + }, +) + +register_model( + "earthfarseer", + earthfarseer_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "hidden_channels": 16, + "num_heads": 4, + 
"num_layers": 2, + }, +) + +register_model( + "convgru_trajgru", + convgru_trajgru_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "enc_channels": 16, + "hidden_channels": 16, + "kernel_size": 3, + }, +) + +register_model( + "tcn", + tcn_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "embed_dim": 16, + "hidden_channels": 16, + "kernel_size": 3, + "num_levels": 3, + "dropout": 0.1, + }, +) + +register_model( + "utae", + utae_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "hidden_channels": 16, + "num_heads": 4, + }, +) + +register_model( + "segformer", + segformer_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "embed_dims": (16, 32), + "num_heads": (1, 2), + "sr_ratios": (4, 2), + "mlp_ratio": 2.0, + "dropout": 0.1, + }, +) + +register_model( + "swin_unet", + swin_unet_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "embed_dims": (16, 32), + "num_heads": (1, 2), + "window_size": 3, + "mlp_ratio": 2.0, + "dropout": 0.1, + }, +) + +register_model( + "vit_segmenter", + vit_segmenter_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "patch_size": 4, + "embed_dim": 64, + "depth": 4, + "num_heads": 4, + "mlp_ratio": 2.0, + "dropout": 0.1, + }, +) + +register_model( + "deep_ensemble", + deep_ensemble_builder, + defaults={ + "in_channels": 1, + "out_dim": 1, + "base_channels": 8, + "ensemble_size": 5, + }, +) diff --git a/pyhazards/models/_wildfire_benchmark_utils.py b/pyhazards/models/_wildfire_benchmark_utils.py new file mode 100644 index 00000000..9b8d0b8e --- /dev/null +++ b/pyhazards/models/_wildfire_benchmark_utils.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import inspect +from typing import Any, Optional + +import numpy as np +import torch +import torch.nn as nn + +from ..datasets.base import DataBundle + + +def require_task(task: str, allowed: set[str], model_name: str) -> None: + normalized = task.lower() + if normalized not in allowed: + allowed_text = ", 
".join(sorted(allowed)) + raise ValueError(f"Model '{model_name}' does not support task={task!r}. Allowed tasks: {allowed_text}") + + +def filter_init_kwargs(callable_obj: Any, kwargs: dict[str, Any]) -> dict[str, Any]: + sig = inspect.signature(callable_obj) + accepts_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) + if accepts_kwargs: + return dict(kwargs) + allowed = {name for name in sig.parameters if name != 'self'} + return {k: v for k, v in kwargs.items() if k in allowed} + + +class SegmentationPort(nn.Module): + def __init__(self, model: nn.Module, out_channels: int = 1): + super().__init__() + self.model = model + self.out_channels = int(out_channels) + self.output_head = nn.Identity() if self.out_channels == 1 else nn.Conv2d(1, self.out_channels, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + logits = self.model(x) + return self.output_head(logits) + + +def _to_numpy_2d(x: torch.Tensor) -> np.ndarray: + if not isinstance(x, torch.Tensor): + raise TypeError('Expected torch.Tensor inputs for estimator-based models.') + x_np = x.detach().cpu().float().numpy() + if x_np.ndim == 1: + x_np = x_np[:, None] + if x_np.ndim > 2: + x_np = x_np.reshape(x_np.shape[0], -1) + return x_np + + +def _to_numpy_labels(y: torch.Tensor) -> np.ndarray: + if not isinstance(y, torch.Tensor): + raise TypeError('Expected torch.Tensor targets for estimator-based models.') + y_np = y.detach().cpu().numpy() + if y_np.ndim > 1: + y_np = y_np.reshape(y_np.shape[0], -1) + if y_np.shape[1] != 1: + raise ValueError('Estimator-based models expect a single target column.') + y_np = y_np[:, 0] + return y_np.astype(np.int64) + + +class EstimatorPort(nn.Module): + def __init__(self): + super().__init__() + self._is_fitted = False + + def fit_bundle( + self, + data: DataBundle, + train_split: str = 'train', + val_split: Optional[str] = None, + **_: Any, + ) -> None: + train_data = data.get_split(train_split) + x_train = 
_to_numpy_2d(train_data.inputs) + y_train = _to_numpy_labels(train_data.targets) + + x_val = None + y_val = None + if val_split: + val_data = data.get_split(val_split) + x_val = _to_numpy_2d(val_data.inputs) + y_val = _to_numpy_labels(val_data.targets) + + self._fit_numpy(x_train=x_train, y_train=y_train, x_val=x_val, y_val=y_val) + self._is_fitted = True + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self._is_fitted: + raise RuntimeError( + f'{self.__class__.__name__} has not been fitted. Use Trainer.fit(...) with a tensor-backed DataBundle first.' + ) + x_np = _to_numpy_2d(x) + probs_pos = self._predict_positive_proba(x_np) + probs = np.stack([1.0 - probs_pos, probs_pos], axis=-1).astype(np.float32) + return torch.from_numpy(probs).to(x.device) + + def _fit_numpy( + self, + x_train: np.ndarray, + y_train: np.ndarray, + x_val: Optional[np.ndarray], + y_val: Optional[np.ndarray], + ) -> None: + raise NotImplementedError + + def _predict_positive_proba(self, x: np.ndarray) -> np.ndarray: + raise NotImplementedError diff --git a/pyhazards/models/asufm.py b/pyhazards/models/asufm.py index 76b51df6..f19fff14 100644 --- a/pyhazards/models/asufm.py +++ b/pyhazards/models/asufm.py @@ -1,81 +1,582 @@ from __future__ import annotations +from typing import Sequence + import torch import torch.nn as nn +import torch.nn.functional as F -class ASUFM(nn.Module): - """Temporal convolution baseline for wildfire activity forecasting.""" +class DropPath(nn.Module): + def __init__(self, drop_prob: float = 0.0): + super().__init__() + self.drop_prob = float(drop_prob) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1.0 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + return x * random_tensor + + +def _window_partition(x: torch.Tensor, window_size: int) -> 
def _window_reverse(windows: torch.Tensor, window_size: int, hp: int, wp: int, batch_size: int) -> torch.Tensor:
    """Inverse of ``_window_partition``: reassemble flat windows into a
    (batch_size, hp, wp, C) map, where hp/wp are the padded dimensions."""
    channels = windows.shape[-1]
    x = windows.view(batch_size, hp // window_size, wp // window_size, window_size, window_size, channels)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(batch_size, hp, wp, channels)


class PatchEmbed(nn.Module):
    """Non-overlapping patch embedding via a strided conv, with optional
    LayerNorm applied channel-last. Input/output layout is (B, C, H, W);
    spatial size and channel count are strictly validated."""

    def __init__(
        self,
        image_size: int = 64,
        patch_size: int = 4,
        in_channels: int = 6,
        embed_dim: int = 96,
        patch_norm: bool = True,
    ):
        super().__init__()
        self.image_size = int(image_size)
        self.patch_size = int(patch_size)
        self.in_channels = int(in_channels)
        self.embed_dim = int(embed_dim)
        # kernel == stride == patch_size: each patch maps to one output pixel.
        self.proj = nn.Conv2d(
            self.in_channels,
            self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )
        self.norm = nn.LayerNorm(self.embed_dim) if patch_norm else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.ndim != 4:
            raise ValueError(f"PatchEmbed expected (B,C,H,W), got {tuple(x.shape)}")
        _, channels, height, width = x.shape
        if channels != self.in_channels:
            raise ValueError(f"PatchEmbed expected {self.in_channels} channels, got {channels}")
        if height != self.image_size or width != self.image_size:
            raise ValueError(
                f"PatchEmbed expected spatial size ({self.image_size}, {self.image_size}), "
                f"got ({height}, {width})"
            )
        x = self.proj(x)
        # LayerNorm normalizes the last dim, so go channel-last and back.
        x = x.permute(0, 2, 3, 1).contiguous()
        x = self.norm(x)
        return x.permute(0, 3, 1, 2).contiguous()


class PatchMerging(nn.Module):
    """Swin-style downsampling: concatenate each 2x2 neighborhood (4*dim
    channels), normalize, and linearly reduce to 2*dim. Halves H and W."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        if h % 2 != 0 or w % 2 != 0:
            raise ValueError(f"PatchMerging requires even spatial dims, got ({h}, {w})")
        x = x.permute(0, 2, 3, 1).contiguous()
        # The four interleaved sub-grids of the 2x2 neighborhoods.
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], dim=-1)
        x = self.norm(x)
        x = self.reduction(x)
        return x.permute(0, 3, 1, 2).contiguous()


class PatchExpand(nn.Module):
    """2x upsampling: 1x1 conv to 2*dim channels, then pixel_shuffle(2)
    yields dim/2 channels at double resolution, followed by LayerNorm."""

    def __init__(self, dim: int):
        super().__init__()
        self.expand = nn.Conv2d(dim, 2 * dim, kernel_size=1, bias=False)
        # After pixel_shuffle by 2, channels become (2*dim)/4 == dim // 2.
        self.norm = nn.LayerNorm(dim // 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.expand(x)
        x = F.pixel_shuffle(x, upscale_factor=2)
        x = x.permute(0, 2, 3, 1).contiguous()
        x = self.norm(x)
        return x.permute(0, 3, 1, 2).contiguous()


class FinalPatchExpandX4(nn.Module):
    """4x upsampling to full resolution: expand to 16*dim channels, then
    pixel_shuffle(4) restores dim channels at 4x the spatial size."""

    def __init__(self, dim: int):
        super().__init__()
        self.expand = nn.Conv2d(dim, 16 * dim, kernel_size=1, bias=False)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.expand(x)
        x = F.pixel_shuffle(x, upscale_factor=4)
        x = x.permute(0, 2, 3, 1).contiguous()
        x = self.norm(x)
        return x.permute(0, 3, 1, 2).contiguous()


class FocalModulation(nn.Module):
    """Focal modulation: project channel-last input into query, context, and
    per-level gates; aggregate depth-wise context at growing kernel sizes
    plus a global-average level, then modulate the query elementwise."""

    def __init__(
        self,
        dim: int,
        focal_window: int = 3,
        focal_level: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.dim = int(dim)
        self.focal_level = int(focal_level)
        # One extra gate channel for the global-context level.
        self.proj = nn.Linear(self.dim, 2 * self.dim + self.focal_level + 1)
        self.depthwise_layers = nn.ModuleList(
            [
                nn.Conv2d(
                    self.dim,
                    self.dim,
                    kernel_size=focal_window + 2 * level,
                    padding=(focal_window + 2 * level) // 2,
                    groups=self.dim,
                    bias=False,
                )
                for level in range(self.focal_level)
            ]
        )
        self.mix = nn.Conv2d(self.dim, self.dim, kernel_size=1, bias=False)
        self.out = nn.Linear(self.dim, self.dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.ndim != 4:
            raise ValueError(f"FocalModulation expected (B,H,W,C), got {tuple(x.shape)}")
        gates_dim = self.focal_level + 1
        projected = self.proj(x)
        q, ctx, gates = torch.split(projected, [self.dim, self.dim, gates_dim], dim=-1)
        ctx = ctx.permute(0, 3, 1, 2).contiguous()
        gates = gates.permute(0, 3, 1, 2).contiguous()

        # Gated sum over focal levels, finished with a global-average level.
        aggregated = 0.0
        for level, layer in enumerate(self.depthwise_layers):
            aggregated = aggregated + layer(ctx) * gates[:, level : level + 1]
        global_ctx = ctx.mean(dim=(2, 3), keepdim=True)
        aggregated = aggregated + global_ctx * gates[:, -1:]

        modulator = self.mix(aggregated).permute(0, 2, 3, 1).contiguous()
        out = q * modulator
        out = self.out(out)
        return self.drop(out)


class MLP(nn.Module):
    """Standard transformer feed-forward block: Linear -> GELU -> Dropout ->
    Linear -> Dropout, with hidden width dim * mlp_ratio."""

    def __init__(self, dim: int, mlp_ratio: float = 4.0, dropout: float = 0.0):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
class SwinFocalBlock(nn.Module):
    """Swin-style block with optional focal modulation, on (B, C, H, W) maps.

    Pipeline: LayerNorm -> optional FocalModulation -> shifted-window
    multi-head self-attention -> residual with stochastic depth -> LayerNorm
    -> MLP -> residual with stochastic depth.

    NOTE(review): the shifted-window attention applies no attention mask to
    windows that straddle the rolled border (official Swin masks them) —
    confirm this simplification is acceptable for this port.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: int,
        shift_size: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        drop_path: float = 0.0,
        use_focal: bool = False,
        focal_window: int = 3,
        focal_level: int = 2,
    ):
        super().__init__()
        self.dim = int(dim)
        self.window_size = int(window_size)
        self.shift_size = int(shift_size)
        self.use_focal = bool(use_focal)
        self.norm1 = nn.LayerNorm(self.dim)
        self.focal = (
            FocalModulation(
                dim=self.dim,
                focal_window=focal_window,
                focal_level=focal_level,
                dropout=dropout,
            )
            if self.use_focal
            else None
        )
        self.attn = nn.MultiheadAttention(self.dim, num_heads=int(num_heads), dropout=dropout, batch_first=True)
        self.drop_path = DropPath(drop_path)
        self.norm2 = nn.LayerNorm(self.dim)
        self.mlp = MLP(dim=self.dim, mlp_ratio=mlp_ratio, dropout=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.ndim != 4:
            raise ValueError(f"SwinFocalBlock expected (B,C,H,W), got {tuple(x.shape)}")
        b, c, h, w = x.shape
        if c != self.dim:
            raise ValueError(f"SwinFocalBlock expected {self.dim} channels, got {c}")

        x_hw = x.permute(0, 2, 3, 1).contiguous()
        x_hw = self.norm1(x_hw)
        if self.focal is not None:
            x_hw = self.focal(x_hw)

        # Clamp window/shift to the feature size so small maps still work.
        effective_window = min(self.window_size, h, w)
        effective_shift = 0 if effective_window <= 1 else min(self.shift_size, effective_window // 2)
        if effective_shift > 0:
            x_hw = torch.roll(x_hw, shifts=(-effective_shift, -effective_shift), dims=(1, 2))

        windows, hp, wp = _window_partition(x_hw, effective_window)
        attn_out, _ = self.attn(windows, windows, windows, need_weights=False)
        windows = windows + attn_out
        x_hw = _window_reverse(windows, effective_window, hp, wp, b)

        if effective_shift > 0:
            x_hw = torch.roll(x_hw, shifts=(effective_shift, effective_shift), dims=(1, 2))
        x_hw = x_hw[:, :h, :w, :]  # drop any padding added by window partition
        x_attn = x_hw.permute(0, 3, 1, 2).contiguous()
        x = x + self.drop_path(x_attn)

        tokens = x.permute(0, 2, 3, 1).reshape(b, h * w, c).contiguous()
        tokens = tokens + self.drop_path(self.mlp(self.norm2(tokens)))
        return tokens.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()


class EncoderStage(nn.Module):
    """A stack of SwinFocalBlocks (alternating shift 0 / window//2) plus an
    optional PatchMerging downsample; returns (downsampled, pre-downsample skip)."""

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        window_size: int,
        mlp_ratio: float,
        dropout: float,
        drop_path_rates: Sequence[float],
        use_focal: bool,
        focal_window: int,
        focal_level: int,
        downsample: bool,
    ):
        super().__init__()
        shift = max(1, window_size // 2)
        self.blocks = nn.ModuleList(
            [
                SwinFocalBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if block_idx % 2 == 0 else shift,
                    mlp_ratio=mlp_ratio,
                    dropout=dropout,
                    drop_path=drop_path_rates[block_idx],
                    use_focal=use_focal,
                    focal_window=focal_window,
                    focal_level=focal_level,
                )
                for block_idx in range(depth)
            ]
        )
        self.downsample = PatchMerging(dim) if downsample else None

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        for block in self.blocks:
            x = block(x)
        skip = x
        if self.downsample is not None:
            x = self.downsample(x)
        return x, skip


class SpatialAttentionGate(nn.Module):
    """Attention-gate a skip connection with a (possibly coarser) gating map.

    The gating map is bilinearly resized to the skip's spatial size, the two
    are concatenated on channels, and a sigmoid mask in (0, 1) scales the skip.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.project = nn.Sequential(
            nn.Conv2d(2 * dim, dim, kernel_size=1, bias=False),
            nn.GELU(),
            nn.Conv2d(dim, 1, kernel_size=1),
        )

    def forward(self, skip: torch.Tensor, gating: torch.Tensor) -> torch.Tensor:
        if gating.shape[-2:] != skip.shape[-2:]:
            gating = F.interpolate(gating, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        mask = torch.sigmoid(self.project(torch.cat([skip, gating], dim=1)))
        return skip * mask
class DecoderStage(nn.Module):
    """U-Net-style decoder stage: gated skip fusion followed by Swin blocks.

    Decoder blocks never use focal modulation (use_focal=False), matching the
    encoder-only focal design of ASUFM.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        window_size: int,
        mlp_ratio: float,
        dropout: float,
        drop_path_rates: Sequence[float],
        spatial_attention: bool,
    ):
        super().__init__()
        self.skip_gate = SpatialAttentionGate(dim) if spatial_attention else None
        self.concat_proj = nn.Conv2d(2 * dim, dim, kernel_size=1, bias=False)
        shift = max(1, window_size // 2)
        self.blocks = nn.ModuleList(
            [
                SwinFocalBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if block_idx % 2 == 0 else shift,
                    mlp_ratio=mlp_ratio,
                    dropout=dropout,
                    drop_path=drop_path_rates[block_idx],
                    use_focal=False,
                )
                for block_idx in range(depth)
            ]
        )

    def forward(self, x: torch.Tensor, skip: torch.Tensor | None) -> torch.Tensor:
        if skip is not None:
            if self.skip_gate is not None:
                skip = self.skip_gate(skip, x)
            x = self.concat_proj(torch.cat([x, skip], dim=1))
        for block in self.blocks:
            x = block(x)
        return x


class ASUFM(nn.Module):
    """
    Self-contained ASUFM port for PyHazards.

    This implementation follows the official ASUFM design at a high level:
    patch embedding, hierarchical Swin-style encoder stages, focal modulation
    in the encoder, and an attention-gated U-Net-style decoder. It
    intentionally avoids external dependencies such as `timm` and `einops` so
    the model can be built directly inside the main PyHazards library.
    """

    def __init__(
        self,
        image_size: int = 64,
        patch_size: int = 4,
        in_channels: int = 6,
        out_dim: int = 1,
        embed_dim: int = 96,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (3, 6, 12, 24),
        window_size: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        drop_path_rate: float = 0.1,
        focal_window: int = 3,
        focal_level: int = 2,
        use_focal_modulation: bool = True,
        spatial_attention: bool = True,
        skip_num: int = 3,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        _ = use_checkpoint  # accepted for upstream config compatibility; unused here

        if len(depths) != 4:
            raise ValueError(f"ASUFM expects 4 encoder depths, got {tuple(depths)}")
        if len(num_heads) != len(depths):
            raise ValueError("num_heads must have the same length as depths")
        if skip_num < 0 or skip_num > 3:
            raise ValueError(f"skip_num must be in [0, 3], got {skip_num}")

        self.image_size = int(image_size)
        self.patch_size = int(patch_size)
        self.in_channels = int(in_channels)
        self.out_dim = int(out_dim)
        self.skip_num = int(skip_num)

        # Channel dims double at every stage: embed_dim * 2**stage.
        dims = [int(embed_dim * (2**idx)) for idx in range(len(depths))]
        for dim, heads in zip(dims, num_heads):
            if dim % int(heads) != 0:
                raise ValueError(f"Channel dim {dim} must be divisible by num_heads={heads}")

        # One linearly increasing stochastic-depth rate per encoder block.
        total_blocks = sum(int(depth) for depth in depths)
        drop_path_values = torch.linspace(0.0, float(drop_path_rate), total_blocks).tolist()

        self.patch_embed = PatchEmbed(
            image_size=self.image_size,
            patch_size=self.patch_size,
            in_channels=self.in_channels,
            embed_dim=int(embed_dim),
            patch_norm=True,
        )

        cursor = 0
        self.encoder_stages = nn.ModuleList()
        for stage_idx, (dim, depth, heads) in enumerate(zip(dims, depths, num_heads)):
            stage_dpr = drop_path_values[cursor : cursor + depth]
            cursor += depth
            self.encoder_stages.append(
                EncoderStage(
                    dim=dim,
                    depth=int(depth),
                    num_heads=int(heads),
                    window_size=int(window_size),
                    mlp_ratio=float(mlp_ratio),
                    dropout=float(dropout),
                    drop_path_rates=stage_dpr,
                    use_focal=bool(use_focal_modulation),
                    focal_window=int(focal_window),
                    focal_level=int(focal_level),
                    downsample=stage_idx < len(depths) - 1,
                )
            )

        # Decoder mirrors the first three encoder stages in reverse order;
        # its drop-path rates reuse the encoder's (reversed), excluding the
        # bottleneck stage's rates.
        reverse_depths = list(reversed(depths[:-1]))
        reverse_heads = list(reversed(num_heads[:-1]))
        reverse_dims = list(reversed(dims[:-1]))
        reverse_drop_paths = list(reversed(drop_path_values[: -depths[-1]]))

        self.upsamplers = nn.ModuleList(
            [
                PatchExpand(dim=dims[-1]),
                PatchExpand(dim=dims[-2]),
                PatchExpand(dim=dims[-3]),
            ]
        )
        self.decoder_stages = nn.ModuleList()
        cursor = 0
        for dim, depth, heads in zip(reverse_dims, reverse_depths, reverse_heads):
            stage_dpr = reverse_drop_paths[cursor : cursor + depth]
            cursor += depth
            self.decoder_stages.append(
                DecoderStage(
                    dim=dim,
                    depth=int(depth),
                    num_heads=int(heads),
                    window_size=int(window_size),
                    mlp_ratio=float(mlp_ratio),
                    dropout=float(dropout),
                    drop_path_rates=stage_dpr,
                    spatial_attention=bool(spatial_attention),
                )
            )

        self.norm_up = nn.LayerNorm(dims[0])
        self.final_up = FinalPatchExpandX4(dim=dims[0])
        self.output_head = nn.Conv2d(dims[0], self.out_dim, kernel_size=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Segment a (B, C, H, W) batch into (B, out_dim, H, W) logits."""
        if x.ndim != 4:
            raise ValueError(f"ASUFM expected input of shape (B,C,H,W), got {tuple(x.shape)}")
        _, channels, height, width = x.shape
        if channels != self.in_channels:
            raise ValueError(f"ASUFM expected {self.in_channels} input channels, got {channels}")

        required_factor = self.patch_size * (2 ** (len(self.encoder_stages) - 1))
        if height != self.image_size or width != self.image_size:
            raise ValueError(
                f"ASUFM expected image_size={self.image_size}, got spatial size ({height}, {width})"
            )
        if height % required_factor != 0 or width % required_factor != 0:
            raise ValueError(
                f"ASUFM requires H and W divisible by {required_factor}, got ({height}, {width})"
            )

        x = self.patch_embed(x)
        skips: list[torch.Tensor] = []
        for stage_idx, stage in enumerate(self.encoder_stages):
            x, skip = stage(x)
            if stage_idx < len(self.encoder_stages) - 1:
                skips.append(skip)

        # decoder_idx runs 1..3; only the first skip_num decoders fuse a skip.
        for decoder_idx, (upsample, decoder_stage) in enumerate(zip(self.upsamplers, self.decoder_stages), start=1):
            x = upsample(x)
            skip = skips[-decoder_idx] if decoder_idx <= self.skip_num else None
            x = decoder_stage(x, skip)

        x = x.permute(0, 2, 3, 1).contiguous()
        x = self.norm_up(x)
        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.final_up(x)
        return self.output_head(x)


def asufm_builder(
    task: str,
    image_size: int = 64,
    patch_size: int = 4,
    in_channels: int = 6,
    out_dim: int = 1,
    embed_dim: int = 96,
    depths: Sequence[int] = (2, 2, 2, 2),
    num_heads: Sequence[int] = (3, 6, 12, 24),
    window_size: int = 8,
    mlp_ratio: float = 4.0,
    dropout: float = 0.0,
    drop_path_rate: float = 0.1,
    focal_window: int = 3,
    focal_level: int = 2,
    use_focal_modulation: bool = True,
    spatial_attention: bool = True,
    skip_num: int = 3,
    use_checkpoint: bool = False,
    in_chans: int | None = None,
    num_classes: int | None = None,
    focal: bool | None = None,
    **kwargs,
) -> nn.Module:
    """Registry builder for ASUFM (segmentation only).

    `in_chans` / `num_classes` / `focal` are upstream-style aliases that,
    when given, override `in_channels` / `out_dim` / `use_focal_modulation`.
    """
    _ = kwargs
    normalized_task = task.lower()
    if normalized_task != "segmentation":
        raise ValueError(f"ASUFM is segmentation-only. Got task='{task}'")

    if in_chans is not None:
        in_channels = int(in_chans)
    if num_classes is not None:
        out_dim = int(num_classes)
    if focal is not None:
        use_focal_modulation = bool(focal)

    return ASUFM(
        image_size=image_size,
        patch_size=patch_size,
        in_channels=in_channels,
        out_dim=out_dim,
        embed_dim=embed_dim,
        depths=tuple(int(v) for v in depths),
        num_heads=tuple(int(v) for v in num_heads),
        window_size=window_size,
        mlp_ratio=mlp_ratio,
        dropout=dropout,
        drop_path_rate=drop_path_rate,
        focal_window=focal_window,
        focal_level=focal_level,
        use_focal_modulation=use_focal_modulation,
        spatial_attention=spatial_attention,
        skip_num=skip_num,
        use_checkpoint=use_checkpoint,
    )

# NOTE(review): the original patch text continued here with the header of a
# new file (pyhazards/models/attention_unet.py); that file's content is not
# part of this module.
import ( + binary_ece, + make_synthetic_fire_maps, + normalized_consistency_score, + split_train_val_test, + ) + + +@dataclass +class AttentionUNetTrackOConfig: + in_channels: int = 1 + base_channels: int = 8 + lr: float = 1e-3 + weight_decay: float = 1e-4 + batch_size: int = 8 + max_epochs: int = 120 + early_stopping_rounds: int = 16 + min_delta: float = 1e-4 + seed: int = 42 + pos_weight_clip_max: float = 50.0 + device: str = "cpu" + + +class ConvBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.block(x) + + +class AttentionGate(nn.Module): + def __init__(self, skip_channels: int, gate_channels: int, inter_channels: int): + super().__init__() + self.w_skip = nn.Conv2d(skip_channels, inter_channels, kernel_size=1) + self.w_gate = nn.Conv2d(gate_channels, inter_channels, kernel_size=1) + self.psi = nn.Conv2d(inter_channels, 1, kernel_size=1) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, x_skip: torch.Tensor, x_gate: torch.Tensor) -> torch.Tensor: + alpha = self.relu(self.w_skip(x_skip) + self.w_gate(x_gate)) + alpha = self.sigmoid(self.psi(alpha)) + return x_skip * alpha + + +class TinyAttentionUNet(nn.Module): + def __init__(self, in_channels: int = 1, base_channels: int = 8): + super().__init__() + c1, c2, c3 = base_channels, base_channels * 2, base_channels * 4 + + self.enc1 = ConvBlock(in_channels, c1) + self.pool1 = nn.MaxPool2d(kernel_size=2) + + self.enc2 = ConvBlock(c1, c2) + self.pool2 = nn.MaxPool2d(kernel_size=2) + + self.bottleneck = ConvBlock(c2, c3) + + self.up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2) + self.att2 = AttentionGate(skip_channels=c2, gate_channels=c2, 
inter_channels=c2 // 2) + self.dec2 = ConvBlock(c2 + c2, c2) + + self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2) + self.att1 = AttentionGate(skip_channels=c1, gate_channels=c1, inter_channels=max(1, c1 // 2)) + self.dec1 = ConvBlock(c1 + c1, c1) + + self.head = nn.Conv2d(c1, 1, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1 = self.enc1(x) + x2 = self.enc2(self.pool1(x1)) + xb = self.bottleneck(self.pool2(x2)) + + y2 = self.up2(xb) + x2_att = self.att2(x2, y2) + y2 = torch.cat([y2, x2_att], dim=1) + y2 = self.dec2(y2) + + y1 = self.up1(y2) + x1_att = self.att1(x1, y1) + y1 = torch.cat([y1, x1_att], dim=1) + y1 = self.dec1(y1) + + return self.head(y1) + + +def _choose_device(device_text: str) -> torch.device: + if device_text == "cuda" and torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader: + ds = TensorDataset( + torch.from_numpy(x.astype(np.float32)), + torch.from_numpy(y.astype(np.float32)), + ) + return DataLoader(ds, batch_size=batch_size, shuffle=shuffle) + + +def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray: + probs: List[np.ndarray] = [] + model.eval() + with torch.no_grad(): + for xb, _ in loader: + xb = xb.to(device) + logits = model(xb) + p = torch.sigmoid(logits).detach().cpu().numpy() + probs.append(p) + if not probs: + return np.zeros((0,), dtype=np.float32) + return np.concatenate(probs, axis=0).reshape(-1) + + +def train_attention_unet_track_o( + x_train: np.ndarray, + y_train: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + cfg: AttentionUNetTrackOConfig, +): + if x_train.ndim != 4 or x_val.ndim != 4: + raise ValueError("x_train and x_val must be 4D arrays [N,C,H,W]") + if y_train.ndim != 4 or y_val.ndim != 4: + raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]") + + torch.manual_seed(cfg.seed) + 
def train_attention_unet_track_o(
    x_train: np.ndarray,
    y_train: np.ndarray,
    x_val: np.ndarray,
    y_val: np.ndarray,
    cfg: AttentionUNetTrackOConfig,
):
    """Train TinyAttentionUNet with early stopping on validation BCE.

    Returns (model, history, val_metrics, best_epoch, pos_weight). The model
    is restored to its best-validation-loss state before metrics are computed.
    """
    if x_train.ndim != 4 or x_val.ndim != 4:
        raise ValueError("x_train and x_val must be 4D arrays [N,C,H,W]")
    if y_train.ndim != 4 or y_val.ndim != 4:
        raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]")

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    device = _choose_device(cfg.device)

    model = TinyAttentionUNet(in_channels=cfg.in_channels, base_channels=cfg.base_channels).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    # Weight the positive class by the (clipped) negative/positive pixel ratio.
    total_px = float(y_train.size)
    pos_px = float(np.sum(y_train))
    neg_px = max(1.0, total_px - pos_px)
    raw_pos_weight = neg_px / max(pos_px, 1.0)
    pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max))

    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

    train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True)
    val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False)

    history: List[Dict[str, float]] = []
    best_epoch = 1
    best_val_loss = float("inf")
    best_state: Dict[str, torch.Tensor] | None = None
    wait = 0

    for epoch in range(1, cfg.max_epochs + 1):
        model.train()
        train_losses: List[float] = []

        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)

            optimizer.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            optimizer.step()
            train_losses.append(float(loss.item()))

        model.eval()
        val_losses: List[float] = []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                logits = model(xb)
                loss = criterion(logits, yb)
                val_losses.append(float(loss.item()))

        tr_loss = float(np.mean(train_losses)) if train_losses else float("nan")
        va_loss = float(np.mean(val_losses)) if val_losses else float("nan")

        history.append(
            {
                "epoch": float(epoch),
                "train_loss": tr_loss,
                "val_loss": va_loss,
                "learning_rate": float(optimizer.param_groups[0]["lr"]),
            }
        )

        # Early stopping: require an improvement of at least min_delta.
        if va_loss < best_val_loss - cfg.min_delta:
            best_val_loss = va_loss
            best_epoch = epoch
            best_state = deepcopy(model.state_dict())
            wait = 0
        else:
            wait += 1

        if wait >= cfg.early_stopping_rounds:
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7)
    val_true = y_val.reshape(-1).astype(np.float32)

    # NOTE(review): this "day-to-day change" is the mean gap between *sorted*
    # probabilities — a distribution-spread statistic, not a temporal one.
    # Presumably a stand-in until real temporal pairing exists; verify.
    mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0
    metrics = {
        "auprc": float(average_precision_score(val_true, val_prob)),
        "auroc": float(roc_auc_score(val_true, val_prob)),
        "brier": float(brier_score_loss(val_true, val_prob)),
        "nll": float(log_loss(val_true, val_prob)),
        "ece": float(binary_ece(val_true, val_prob, n_bins=15)),
        "mean_day_to_day_change": mean_change,
        "normalized_consistency_score": normalized_consistency_score(mean_change),
    }

    return model, history, metrics, best_epoch, pos_weight


def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None:
    """Write history.csv and a train/val loss curve PNG into output_dir."""
    output_dir.mkdir(parents=True, exist_ok=True)

    history_csv = output_dir / "history.csv"
    with history_csv.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"])
        writer.writeheader()
        writer.writerows(history)

    x = [int(r["epoch"]) for r in history]
    y_tr = [float(r["train_loss"]) for r in history]
    y_va = [float(r["val_loss"]) for r in history]

    plt.figure(figsize=(8, 5))
    plt.plot(x, y_tr, label="train_bce", marker="o", linewidth=1.4)
    plt.plot(x, y_va, label="val_bce", marker="s", linewidth=1.2)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("Attention U-Net Track-O: train loss vs epoch")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_dir / "loss_curve.png", dpi=150)
    plt.close()
def build_experiment_setting(
    cfg: AttentionUNetTrackOConfig,
    best_epoch: int,
    pos_weight: float,
    metrics: Dict[str, float],
) -> Dict[str, Any]:
    """Assemble the JSON-serializable experiment card for a Track-O run."""
    return {
        "benchmark": {
            "task": "Track-O",
            "model_name": "attention_unet",
            "run_time": datetime.now().isoformat(),
        },
        "evaluation_protocol": {
            "discrimination": {"primary": "auprc", "secondary": "auroc"},
            "reliability": ["brier", "nll", "ece"],
            "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"],
        },
        "training": {
            "train_unit": "epoch",
            "max_epochs": cfg.max_epochs,
            "early_stopping_rounds": cfg.early_stopping_rounds,
            "best_epoch": best_epoch,
            "seed": cfg.seed,
            "batch_size": cfg.batch_size,
        },
        "optimizer": {
            "name": "AdamW",
            "lr": cfg.lr,
            "weight_decay": cfg.weight_decay,
        },
        "learning_weight": {
            "type": "pixel_pos_weight",
            "value": pos_weight,
            "clip_max": cfg.pos_weight_clip_max,
        },
        "params": asdict(cfg),
        "val_metrics": metrics,
        "note": "This module supports both real data and synthetic smoke demonstration.",
    }


def run_synthetic_demo(
    output_dir: Path,
    seed: int = 42,
    n_samples: int = 192,
    image_size: int = 24,
    max_epochs: int = 60,
    early_stopping_rounds: int = 12,
) -> None:
    """Train on synthetic fire maps and save model, metrics, and plots."""
    x, y = make_synthetic_fire_maps(n_samples=n_samples, image_size=image_size, seed=seed)
    x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed)

    cfg = AttentionUNetTrackOConfig(
        seed=seed,
        max_epochs=max_epochs,
        early_stopping_rounds=early_stopping_rounds,
        device="cpu",
    )

    model, history, val_metrics, best_epoch, pos_weight = train_attention_unet_track_o(
        x_train, y_train, x_val, y_val, cfg
    )

    test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False)
    test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 1e-7)
    test_true = y_test.reshape(-1).astype(np.float32)

    # Same sorted-diff consistency proxy as used on the validation split.
    test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0
    test_metrics = {
        "auprc": float(average_precision_score(test_true, test_prob)),
        "auroc": float(roc_auc_score(test_true, test_prob)),
        "brier": float(brier_score_loss(test_true, test_prob)),
        "nll": float(log_loss(test_true, test_prob)),
        "ece": float(binary_ece(test_true, test_prob, n_bins=15)),
        "mean_day_to_day_change": test_mean_change,
        "normalized_consistency_score": normalized_consistency_score(test_mean_change),
    }

    output_dir.mkdir(parents=True, exist_ok=True)
    save_history_and_plot(history, output_dir)

    torch.save(
        {
            "state_dict": model.state_dict(),
            "config": asdict(cfg),
            "best_epoch": best_epoch,
        },
        output_dir / "attention_unet_model.pt",
    )

    setting = build_experiment_setting(cfg, best_epoch=best_epoch, pos_weight=pos_weight, metrics=val_metrics)
    setting["data"] = {
        "n_samples": n_samples,
        "image_size": image_size,
        "split": {"train": int(x_train.shape[0]), "val": int(x_val.shape[0]), "test": int(x_test.shape[0])},
    }
    setting["test_metrics"] = test_metrics

    (output_dir / "experiment_setting.json").write_text(json.dumps(setting, indent=2), encoding="utf-8")
    (output_dir / "metrics.json").write_text(
        json.dumps({"val": val_metrics, "test": test_metrics, "best_epoch": best_epoch}, indent=2),
        encoding="utf-8",
    )


def parse_args() -> argparse.Namespace:
    """CLI flags for the synthetic smoke demo."""
    parser = argparse.ArgumentParser(description="Run Attention U-Net Track-O synthetic smoke demo")
    parser.add_argument("--output_dir", default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--n_samples", type=int, default=192)
    parser.add_argument("--image_size", type=int, default=24)
    parser.add_argument("--max_epochs", type=int, default=60)
    parser.add_argument("--early_stopping_rounds", type=int, default=12)
    return parser.parse_args()
def main() -> None:
    """CLI entry point: run the synthetic smoke demo into a timestamped dir."""
    args = parse_args()
    base = Path(__file__).resolve().parents[1] / "runs_scaffold"
    out = (
        Path(args.output_dir)
        if args.output_dir
        else base / f"attention_unet_synthetic_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    )

    run_synthetic_demo(
        output_dir=out,
        seed=args.seed,
        n_samples=args.n_samples,
        image_size=args.image_size,
        max_epochs=args.max_epochs,
        early_stopping_rounds=args.early_stopping_rounds,
    )
    print(f"[done] attention unet synthetic demo saved to: {out}")


if __name__ == "__main__":
    main()


# NOTE(review): the import and definitions below sit *after* the __main__
# guard. They still execute on every module import, but conventionally they
# belong near the top of the file; left in place to preserve the original
# layout of the patch.
from ._wildfire_benchmark_utils import SegmentationPort, filter_init_kwargs, require_task


def attention_unet_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module:
    """Registry builder: wrap TinyAttentionUNet in a SegmentationPort.

    Raises (via require_task) when task is not 'segmentation'; unknown kwargs
    are filtered against TinyAttentionUNet.__init__.
    """
    require_task(task, {"segmentation"}, "attention_unet")
    init_kwargs = filter_init_kwargs(TinyAttentionUNet, {"in_channels": int(in_channels), **kwargs})
    model = TinyAttentionUNet(**init_kwargs)
    return SegmentationPort(model=model, out_channels=int(out_dim))


__all__ = ["TinyAttentionUNet", "attention_unet_builder"]

# NOTE(review): the original patch text continued here with the header of a
# new file (pyhazards/models/convgru_trajgru.py); that file's content is not
# part of this module.
ConvGRTrajGRUTrackOConfig: + seq_len: int = 6 + in_channels: int = 1 + enc_channels: int = 16 + hidden_channels: int = 16 + kernel_size: int = 3 + lr: float = 3e-4 + weight_decay: float = 1e-4 + batch_size: int = 8 + max_epochs: int = 120 + early_stopping_rounds: int = 16 + min_delta: float = 1e-4 + seed: int = 42 + pos_weight_clip_max: float = 50.0 + device: str = "cpu" + + +def _normalized_base_grid(h: int, w: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + ys = torch.linspace(-1.0, 1.0, steps=h, device=device, dtype=dtype) + xs = torch.linspace(-1.0, 1.0, steps=w, device=device, dtype=dtype) + yy, xx = torch.meshgrid(ys, xs, indexing="ij") + return torch.stack((xx, yy), dim=-1) # [H,W,2] + + +def _warp_hidden(hidden: torch.Tensor, flow_xy: torch.Tensor) -> torch.Tensor: + # hidden: [B,C,H,W], flow_xy: [B,2,H,W] in pixel space + b, _, h, w = hidden.shape + base = _normalized_base_grid(h, w, hidden.device, hidden.dtype).unsqueeze(0).repeat(b, 1, 1, 1) + + fx = flow_xy[:, 0] / max((w - 1) / 2.0, 1.0) + fy = flow_xy[:, 1] / max((h - 1) / 2.0, 1.0) + flow = torch.stack((fx, fy), dim=-1) # [B,H,W,2] + + grid = base + flow + return F.grid_sample(hidden, grid, mode="bilinear", padding_mode="border", align_corners=True) + + +class ConvGRUCell(nn.Module): + def __init__(self, input_channels: int, hidden_channels: int, kernel_size: int = 3): + super().__init__() + padding = kernel_size // 2 + self.conv_zr = nn.Conv2d(input_channels + hidden_channels, hidden_channels * 2, kernel_size, padding=padding) + self.conv_n = nn.Conv2d(input_channels + hidden_channels, hidden_channels, kernel_size, padding=padding) + + def forward(self, x_t: torch.Tensor, h_prev: torch.Tensor) -> torch.Tensor: + fused = torch.cat([x_t, h_prev], dim=1) + z, r = torch.chunk(self.conv_zr(fused), 2, dim=1) + z = torch.sigmoid(z) + r = torch.sigmoid(r) + + n = torch.tanh(self.conv_n(torch.cat([x_t, r * h_prev], dim=1))) + return (1.0 - z) * h_prev + z * n + + +class 
TrajGRUCell(nn.Module): + def __init__(self, input_channels: int, hidden_channels: int, kernel_size: int = 3): + super().__init__() + padding = kernel_size // 2 + self.flow_net = nn.Conv2d(input_channels + hidden_channels, 2, kernel_size=3, padding=1) + self.conv_zr = nn.Conv2d(input_channels + hidden_channels, hidden_channels * 2, kernel_size, padding=padding) + self.conv_n = nn.Conv2d(input_channels + hidden_channels, hidden_channels, kernel_size, padding=padding) + + def forward(self, x_t: torch.Tensor, h_prev: torch.Tensor) -> torch.Tensor: + flow = self.flow_net(torch.cat([x_t, h_prev], dim=1)) + h_warp = _warp_hidden(h_prev, flow) + + fused = torch.cat([x_t, h_warp], dim=1) + z, r = torch.chunk(self.conv_zr(fused), 2, dim=1) + z = torch.sigmoid(z) + r = torch.sigmoid(r) + + n = torch.tanh(self.conv_n(torch.cat([x_t, r * h_warp], dim=1))) + return (1.0 - z) * h_warp + z * n + + +class TinyConvGRTrajGRU(nn.Module): + def __init__(self, in_channels: int = 1, enc_channels: int = 16, hidden_channels: int = 16, kernel_size: int = 3): + super().__init__() + self.hidden_channels = hidden_channels + + self.encoder = nn.Sequential( + nn.Conv2d(in_channels, enc_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(enc_channels, enc_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ) + self.convgru = ConvGRUCell(input_channels=enc_channels, hidden_channels=hidden_channels, kernel_size=kernel_size) + self.trajgru = TrajGRUCell(input_channels=hidden_channels, hidden_channels=hidden_channels, kernel_size=kernel_size) + + self.decoder = nn.Sequential( + nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_channels, 1, kernel_size=1), + ) + + def forward(self, x_seq: torch.Tensor) -> torch.Tensor: + # x_seq: [B,T,C,H,W] + b, _, _, h, w = x_seq.shape + device = x_seq.device + dtype = x_seq.dtype + + h_conv = torch.zeros((b, self.hidden_channels, h, w), device=device, dtype=dtype) + 
def _choose_device(device_text: str) -> torch.device:
    """Return CUDA only when requested AND available; otherwise CPU."""
    if device_text == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader:
    """Wrap float32 copies of (x, y) in a TensorDataset DataLoader."""
    ds = TensorDataset(
        torch.from_numpy(x.astype(np.float32)),
        torch.from_numpy(y.astype(np.float32)),
    )
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)


def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray:
    """Run the model over a loader and return flattened sigmoid probabilities."""
    probs: List[np.ndarray] = []
    model.eval()
    with torch.no_grad():
        for xb, _ in loader:
            xb = xb.to(device)
            logits = model(xb)
            p = torch.sigmoid(logits).detach().cpu().numpy()
            probs.append(p)
    if not probs:
        return np.zeros((0,), dtype=np.float32)
    return np.concatenate(probs, axis=0).reshape(-1)


def train_convgru_trajgru_track_o(
    x_train: np.ndarray,
    y_train: np.ndarray,
    x_val: np.ndarray,
    y_val: np.ndarray,
    cfg: ConvGRTrajGRUTrackOConfig,
):
    """Train TinyConvGRTrajGRU with early stopping on validation BCE.

    Inputs are sequences [N,T,C,H,W]; targets are maps [N,1,H,W]. Returns
    (model, history, val_metrics, best_epoch, pos_weight) with the model
    restored to its best-validation-loss state.
    """
    if x_train.ndim != 5 or x_val.ndim != 5:
        raise ValueError("x_train and x_val must be 5D arrays [N,T,C,H,W]")
    if y_train.ndim != 4 or y_val.ndim != 4:
        raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]")

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    device = _choose_device(cfg.device)

    model = TinyConvGRTrajGRU(
        in_channels=cfg.in_channels,
        enc_channels=cfg.enc_channels,
        hidden_channels=cfg.hidden_channels,
        kernel_size=cfg.kernel_size,
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    # Weight the positive class by the (clipped) negative/positive pixel ratio.
    total_px = float(y_train.size)
    pos_px = float(np.sum(y_train))
    neg_px = max(1.0, total_px - pos_px)
    raw_pos_weight = neg_px / max(pos_px, 1.0)
    pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max))

    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

    train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True)
    val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False)

    history: List[Dict[str, float]] = []
    best_epoch = 1
    best_val_loss = float("inf")
    best_state: Dict[str, torch.Tensor] | None = None
    wait = 0

    for epoch in range(1, cfg.max_epochs + 1):
        model.train()
        train_losses: List[float] = []

        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)

            optimizer.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            optimizer.step()
            train_losses.append(float(loss.item()))

        model.eval()
        val_losses: List[float] = []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                logits = model(xb)
                loss = criterion(logits, yb)
                val_losses.append(float(loss.item()))

        tr_loss = float(np.mean(train_losses)) if train_losses else float("nan")
        va_loss = float(np.mean(val_losses)) if val_losses else float("nan")

        history.append(
            {
                "epoch": float(epoch),
                "train_loss": tr_loss,
                "val_loss": va_loss,
                "learning_rate": float(optimizer.param_groups[0]["lr"]),
            }
        )

        # Early stopping: require an improvement of at least min_delta.
        if va_loss < best_val_loss - cfg.min_delta:
            best_val_loss = va_loss
            best_epoch = epoch
            best_state = deepcopy(model.state_dict())
            wait = 0
        else:
            wait += 1

        if wait >= cfg.early_stopping_rounds:
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7)
    val_true = y_val.reshape(-1).astype(np.float32)

    # NOTE(review): mean gap between *sorted* probabilities — a spread
    # statistic, not a temporal one; same caveat as the attention_unet module.
    mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0
    metrics = {
        "auprc": float(average_precision_score(val_true, val_prob)),
        "auroc": float(roc_auc_score(val_true, val_prob)),
        "brier": float(brier_score_loss(val_true, val_prob)),
        "nll": float(log_loss(val_true, val_prob)),
        "ece": float(binary_ece(val_true, val_prob, n_bins=15)),
        "mean_day_to_day_change": mean_change,
        "normalized_consistency_score": normalized_consistency_score(mean_change),
    }

    return model, history, metrics, best_epoch, pos_weight
def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None:
    """Persist the per-epoch history as ``history.csv`` and render the
    train/val BCE loss curves to ``loss_curve.png`` inside *output_dir*.

    Creates *output_dir* (including parents) when missing.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    fieldnames = ["epoch", "train_loss", "val_loss", "learning_rate"]
    with (output_dir / "history.csv").open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        for row in history:
            writer.writerow(row)

    epochs = [int(row["epoch"]) for row in history]
    train_curve = [float(row["train_loss"]) for row in history]
    val_curve = [float(row["val_loss"]) for row in history]

    plt.figure(figsize=(8, 5))
    plt.plot(epochs, train_curve, label="train_bce", marker="o", linewidth=1.4)
    plt.plot(epochs, val_curve, label="val_bce", marker="s", linewidth=1.2)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("ConvGRU/TrajGRU Track-O: train loss vs epoch")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_dir / "loss_curve.png", dpi=150)
    plt.close()
cfg.early_stopping_rounds, + "best_epoch": best_epoch, + "seed": cfg.seed, + "batch_size": cfg.batch_size, + "seq_len": cfg.seq_len, + }, + "optimizer": { + "name": "AdamW", + "lr": cfg.lr, + "weight_decay": cfg.weight_decay, + }, + "learning_weight": { + "type": "pixel_pos_weight", + "value": pos_weight, + "clip_max": cfg.pos_weight_clip_max, + }, + "params": asdict(cfg), + "val_metrics": metrics, + "note": "This module supports both real data and synthetic smoke demonstration.", + } + + +def run_synthetic_demo( + output_dir: Path, + seed: int = 42, + n_samples: int = 160, + seq_len: int = 6, + image_size: int = 24, + max_epochs: int = 60, + early_stopping_rounds: int = 12, +) -> None: + x, y = make_synthetic_fire_sequences( + n_samples=n_samples, + seq_len=seq_len, + image_size=image_size, + seed=seed, + ) + x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed) + + cfg = ConvGRTrajGRUTrackOConfig( + seed=seed, + seq_len=seq_len, + max_epochs=max_epochs, + early_stopping_rounds=early_stopping_rounds, + device="cpu", + ) + + model, history, val_metrics, best_epoch, pos_weight = train_convgru_trajgru_track_o( + x_train, + y_train, + x_val, + y_val, + cfg, + ) + + test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False) + test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 1e-7) + test_true = y_test.reshape(-1).astype(np.float32) + + test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0 + test_metrics = { + "auprc": float(average_precision_score(test_true, test_prob)), + "auroc": float(roc_auc_score(test_true, test_prob)), + "brier": float(brier_score_loss(test_true, test_prob)), + "nll": float(log_loss(test_true, test_prob)), + "ece": float(binary_ece(test_true, test_prob, n_bins=15)), + "mean_day_to_day_change": test_mean_change, + "normalized_consistency_score": 
def parse_args() -> argparse.Namespace:
    """Parse the CLI options for the ConvGRU/TrajGRU synthetic smoke demo."""
    parser = argparse.ArgumentParser(description="Run ConvGRU/TrajGRU Track-O synthetic smoke demo")
    parser.add_argument("--output_dir", default=None)
    # All remaining options are integer-valued; register them uniformly.
    for flag, default in (
        ("--seed", 42),
        ("--n_samples", 160),
        ("--seq_len", 6),
        ("--image_size", 24),
        ("--max_epochs", 60),
        ("--early_stopping_rounds", 12),
    ):
        parser.add_argument(flag, type=int, default=default)
    return parser.parse_args()
def convgru_trajgru_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module:
    """Benchmark factory: wrap ``TinyConvGRTrajGRU`` in a ``SegmentationPort``.

    Delegates task validation to ``require_task`` (rejects anything other
    than "segmentation") and forwards only constructor-accepted kwargs.
    """
    require_task(task, {"segmentation"}, "convgru_trajgru")
    requested = {"in_channels": int(in_channels), **kwargs}
    accepted = filter_init_kwargs(TinyConvGRTrajGRU, requested)
    return SegmentationPort(model=TinyConvGRTrajGRU(**accepted), out_channels=int(out_dim))
class ConvLSTMCell(nn.Module):
    """Single ConvLSTM cell: one fused convolution produces all four gates."""

    def __init__(self, input_channels: int, hidden_channels: int, kernel_size: int = 3):
        super().__init__()
        self.hidden_channels = hidden_channels
        # One conv over the channel-concatenated [x_t ; h_prev] emits the
        # i/f/o/g gate pre-activations in a single pass.
        self.conv = nn.Conv2d(
            input_channels + hidden_channels,
            hidden_channels * 4,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )

    def forward(self, x_t: torch.Tensor, h_prev: torch.Tensor, c_prev: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Advance the cell one time step; return the new (hidden, cell) pair."""
        stacked = self.conv(torch.cat((x_t, h_prev), dim=1))
        gate_i, gate_f, gate_o, gate_g = torch.chunk(stacked, 4, dim=1)

        # Standard LSTM update with sigmoid gates and tanh candidate.
        c_next = torch.sigmoid(gate_f) * c_prev + torch.sigmoid(gate_i) * torch.tanh(gate_g)
        h_next = torch.sigmoid(gate_o) * torch.tanh(c_next)
        return h_next, c_next
x_seq.device + + h_states = [ + torch.zeros((b, self.hidden_channels, h, w), device=device, dtype=x_seq.dtype) + for _ in range(self.num_layers) + ] + c_states = [ + torch.zeros((b, self.hidden_channels, h, w), device=device, dtype=x_seq.dtype) + for _ in range(self.num_layers) + ] + + for t in range(x_seq.shape[1]): + x_t = self.encoder(x_seq[:, t]) # [B,enc,H,W] + for i, cell in enumerate(self.cells): + h_i, c_i = cell(x_t, h_states[i], c_states[i]) + h_states[i], c_states[i] = h_i, c_i + x_t = h_i + + return self.decoder(h_states[-1]) + + +def _choose_device(device_text: str) -> torch.device: + normalized = str(device_text).strip().lower() + if normalized.startswith("cuda") and torch.cuda.is_available(): + return torch.device(device_text) + return torch.device("cpu") + + +def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader: + ds = TensorDataset( + torch.from_numpy(x.astype(np.float32)), + torch.from_numpy(y.astype(np.float32)), + ) + return DataLoader(ds, batch_size=batch_size, shuffle=shuffle) + + +def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray: + probs: List[np.ndarray] = [] + model.eval() + with torch.no_grad(): + for xb, _ in loader: + xb = xb.to(device) + logits = model(xb) + p = torch.sigmoid(logits).detach().cpu().numpy() + probs.append(p) + if not probs: + return np.zeros((0,), dtype=np.float32) + return np.concatenate(probs, axis=0).reshape(-1) + + +def train_convlstm_track_o( + x_train: np.ndarray, + y_train: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + cfg: ConvLSTMTrackOConfig, +): + if x_train.ndim != 5 or x_val.ndim != 5: + raise ValueError("x_train and x_val must be 5D arrays [N,T,C,H,W]") + if y_train.ndim != 4 or y_val.ndim != 4: + raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]") + + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + + device = _choose_device(cfg.device) + + model = TinyConvLSTM( + 
in_channels=cfg.in_channels, + enc_channels=cfg.enc_channels, + hidden_channels=cfg.hidden_channels, + num_layers=cfg.num_layers, + kernel_size=cfg.kernel_size, + ).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) + + total_px = float(y_train.size) + pos_px = float(np.sum(y_train)) + neg_px = max(1.0, total_px - pos_px) + raw_pos_weight = neg_px / max(pos_px, 1.0) + pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max)) + + criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device)) + + train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True) + val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False) + + history: List[Dict[str, float]] = [] + best_epoch = 1 + best_val_loss = float("inf") + best_state: Dict[str, torch.Tensor] | None = None + wait = 0 + + for epoch in range(1, cfg.max_epochs + 1): + model.train() + train_losses: List[float] = [] + + for xb, yb in train_loader: + xb = xb.to(device) + yb = yb.to(device) + + optimizer.zero_grad(set_to_none=True) + logits = model(xb) + loss = criterion(logits, yb) + loss.backward() + optimizer.step() + train_losses.append(float(loss.item())) + + model.eval() + val_losses: List[float] = [] + with torch.no_grad(): + for xb, yb in val_loader: + xb = xb.to(device) + yb = yb.to(device) + logits = model(xb) + loss = criterion(logits, yb) + val_losses.append(float(loss.item())) + + tr_loss = float(np.mean(train_losses)) if train_losses else float("nan") + va_loss = float(np.mean(val_losses)) if val_losses else float("nan") + + history.append( + { + "epoch": float(epoch), + "train_loss": tr_loss, + "val_loss": va_loss, + "learning_rate": float(optimizer.param_groups[0]["lr"]), + } + ) + + if va_loss < best_val_loss - cfg.min_delta: + best_val_loss = va_loss + best_epoch = epoch + best_state = deepcopy(model.state_dict()) + wait = 0 + else: + wait += 1 + + if wait >= 
cfg.early_stopping_rounds: + break + + if best_state is not None: + model.load_state_dict(best_state) + + val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7) + val_true = y_val.reshape(-1).astype(np.float32) + + mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0 + metrics = { + "auprc": float(average_precision_score(val_true, val_prob)), + "auroc": float(roc_auc_score(val_true, val_prob)), + "brier": float(brier_score_loss(val_true, val_prob)), + "nll": float(log_loss(val_true, val_prob)), + "ece": float(binary_ece(val_true, val_prob, n_bins=15)), + "mean_day_to_day_change": mean_change, + "normalized_consistency_score": normalized_consistency_score(mean_change), + } + + return model, history, metrics, best_epoch, pos_weight + + +def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + history_csv = output_dir / "history.csv" + with history_csv.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"]) + writer.writeheader() + writer.writerows(history) + + x = [int(r["epoch"]) for r in history] + y_tr = [float(r["train_loss"]) for r in history] + y_va = [float(r["val_loss"]) for r in history] + + plt.figure(figsize=(8, 5)) + plt.plot(x, y_tr, label="train_bce", marker="o", linewidth=1.4) + plt.plot(x, y_va, label="val_bce", marker="s", linewidth=1.2) + plt.xlabel("epoch") + plt.ylabel("loss") + plt.title("ConvLSTM Track-O: train loss vs epoch") + plt.grid(alpha=0.3) + plt.legend() + plt.tight_layout() + plt.savefig(output_dir / "loss_curve.png", dpi=150) + plt.close() + + +def build_experiment_setting( + cfg: ConvLSTMTrackOConfig, + best_epoch: int, + pos_weight: float, + metrics: Dict[str, float], +) -> Dict[str, Any]: + return { + "benchmark": { + "task": "Track-O", + "model_name": "convlstm", + "run_time": 
datetime.now().isoformat(), + }, + "evaluation_protocol": { + "discrimination": {"primary": "auprc", "secondary": "auroc"}, + "reliability": ["brier", "nll", "ece"], + "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"], + }, + "training": { + "train_unit": "epoch", + "max_epochs": cfg.max_epochs, + "early_stopping_rounds": cfg.early_stopping_rounds, + "best_epoch": best_epoch, + "seed": cfg.seed, + "batch_size": cfg.batch_size, + "seq_len": cfg.seq_len, + }, + "optimizer": { + "name": "AdamW", + "lr": cfg.lr, + "weight_decay": cfg.weight_decay, + }, + "learning_weight": { + "type": "pixel_pos_weight", + "value": pos_weight, + "clip_max": cfg.pos_weight_clip_max, + }, + "params": asdict(cfg), + "val_metrics": metrics, + "note": "This module supports both real data and synthetic smoke demonstration.", + } + + +def run_synthetic_demo( + output_dir: Path, + seed: int = 42, + n_samples: int = 160, + seq_len: int = 6, + image_size: int = 24, + max_epochs: int = 60, + early_stopping_rounds: int = 12, +) -> None: + x, y = make_synthetic_fire_sequences( + n_samples=n_samples, + seq_len=seq_len, + image_size=image_size, + seed=seed, + ) + x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed) + + cfg = ConvLSTMTrackOConfig( + seed=seed, + seq_len=seq_len, + max_epochs=max_epochs, + early_stopping_rounds=early_stopping_rounds, + device="cpu", + ) + + model, history, val_metrics, best_epoch, pos_weight = train_convlstm_track_o( + x_train, + y_train, + x_val, + y_val, + cfg, + ) + + test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False) + test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 1e-7) + test_true = y_test.reshape(-1).astype(np.float32) + + test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0 + test_metrics = { + "auprc": float(average_precision_score(test_true, test_prob)), 
def parse_args() -> argparse.Namespace:
    """Parse CLI options for the ConvLSTM Track-O synthetic smoke demo."""
    parser = argparse.ArgumentParser(description="Run ConvLSTM Track-O synthetic smoke demo")
    parser.add_argument("--output_dir", default=None)
    # Integer-valued options share one registration loop.
    int_options = {
        "--seed": 42,
        "--n_samples": 160,
        "--seq_len": 6,
        "--image_size": 24,
        "--max_epochs": 60,
        "--early_stopping_rounds": 12,
    }
    for flag, default in int_options.items():
        parser.add_argument(flag, type=int, default=default)
    return parser.parse_args()
def convlstm_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module:
    """Benchmark factory: wrap ``TinyConvLSTM`` in a ``SegmentationPort``.

    ``require_task`` rejects any task other than "segmentation";
    ``filter_init_kwargs`` drops kwargs the constructor does not accept.
    """
    require_task(task, {"segmentation"}, "convlstm")
    candidate_kwargs = {"in_channels": int(in_channels), **kwargs}
    accepted_kwargs = filter_init_kwargs(TinyConvLSTM, candidate_kwargs)
    core = TinyConvLSTM(**accepted_kwargs)
    return SegmentationPort(model=core, out_channels=int(out_dim))
+ split_train_val_test, + ) + + +@dataclass +class DeepEnsembleTrackOConfig: + in_channels: int = 1 + base_channels: int = 8 + ensemble_size: int = 5 + lr: float = 1e-3 + weight_decay: float = 1e-4 + batch_size: int = 8 + max_epochs: int = 120 + early_stopping_rounds: int = 16 + min_delta: float = 1e-4 + seed: int = 42 + pos_weight_clip_max: float = 50.0 + device: str = "cpu" + + +class ConvBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.block(x) + + +class TinyEnsembleMember(nn.Module): + def __init__(self, in_channels: int = 1, base_channels: int = 8): + super().__init__() + c1, c2 = base_channels, base_channels * 2 + self.enc1 = ConvBlock(in_channels, c1) + self.pool = nn.MaxPool2d(kernel_size=2) + self.enc2 = ConvBlock(c1, c2) + self.up = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2) + self.dec = ConvBlock(c1 + c1, c1) + self.head = nn.Conv2d(c1, 1, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1 = self.enc1(x) + x2 = self.enc2(self.pool(x1)) + y = self.up(x2) + y = torch.cat([y, x1], dim=1) + y = self.dec(y) + return self.head(y) + + +def _choose_device(device_text: str) -> torch.device: + if device_text == "cuda" and torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader: + ds = TensorDataset( + torch.from_numpy(x.astype(np.float32)), + torch.from_numpy(y.astype(np.float32)), + ) + return DataLoader(ds, batch_size=batch_size, shuffle=shuffle) + + +def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray: + probs: 
List[np.ndarray] = [] + model.eval() + with torch.no_grad(): + for xb, _ in loader: + xb = xb.to(device) + logits = model(xb) + p = torch.sigmoid(logits).detach().cpu().numpy() + probs.append(p) + if not probs: + return np.zeros((0,), dtype=np.float32) + return np.concatenate(probs, axis=0).reshape(-1) + + +def _predict_ensemble_probabilities(models: List[nn.Module], loader: DataLoader, device: torch.device) -> Tuple[np.ndarray, float]: + member_probs: List[np.ndarray] = [] + for model in models: + member_probs.append(_predict_probabilities(model, loader, device)) + if not member_probs: + return np.zeros((0,), dtype=np.float32), 0.0 + stacked = np.stack(member_probs, axis=0) # [M,N] + return np.mean(stacked, axis=0), float(np.mean(np.std(stacked, axis=0))) + + +def _train_one_epoch(model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer, criterion: nn.Module, device: torch.device) -> float: + model.train() + losses: List[float] = [] + for xb, yb in loader: + xb = xb.to(device) + yb = yb.to(device) + + optimizer.zero_grad(set_to_none=True) + logits = model(xb) + loss = criterion(logits, yb) + loss.backward() + optimizer.step() + losses.append(float(loss.item())) + return float(np.mean(losses)) if losses else float("nan") + + +def _eval_loss(model: nn.Module, loader: DataLoader, criterion: nn.Module, device: torch.device) -> float: + model.eval() + losses: List[float] = [] + with torch.no_grad(): + for xb, yb in loader: + xb = xb.to(device) + yb = yb.to(device) + logits = model(xb) + loss = criterion(logits, yb) + losses.append(float(loss.item())) + return float(np.mean(losses)) if losses else float("nan") + + +def train_deep_ensemble_track_o( + x_train: np.ndarray, + y_train: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + cfg: DeepEnsembleTrackOConfig, +): + if x_train.ndim != 4 or x_val.ndim != 4: + raise ValueError("x_train and x_val must be 4D arrays [N,C,H,W]") + if y_train.ndim != 4 or y_val.ndim != 4: + raise ValueError("y_train and 
y_val must be 4D arrays [N,1,H,W]") + if cfg.ensemble_size < 1: + raise ValueError("ensemble_size must be >= 1") + + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + + device = _choose_device(cfg.device) + + total_px = float(y_train.size) + pos_px = float(np.sum(y_train)) + neg_px = max(1.0, total_px - pos_px) + raw_pos_weight = neg_px / max(pos_px, 1.0) + pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max)) + criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device)) + + train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True) + val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False) + + models: List[nn.Module] = [] + optimizers: List[torch.optim.Optimizer] = [] + for m in range(cfg.ensemble_size): + member_seed = cfg.seed + 1000 + m + torch.manual_seed(member_seed) + np.random.seed(member_seed) + member = TinyEnsembleMember(in_channels=cfg.in_channels, base_channels=cfg.base_channels).to(device) + opt = torch.optim.AdamW(member.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) + models.append(member) + optimizers.append(opt) + + history: List[Dict[str, float]] = [] + best_epoch = 1 + best_val_loss = float("inf") + best_states: List[Dict[str, torch.Tensor]] | None = None + wait = 0 + + for epoch in range(1, cfg.max_epochs + 1): + train_member_losses: List[float] = [] + val_member_losses: List[float] = [] + + for model, opt in zip(models, optimizers): + tr = _train_one_epoch(model, train_loader, opt, criterion, device) + va = _eval_loss(model, val_loader, criterion, device) + train_member_losses.append(tr) + val_member_losses.append(va) + + tr_loss = float(np.mean(train_member_losses)) if train_member_losses else float("nan") + va_loss = float(np.mean(val_member_losses)) if val_member_losses else float("nan") + + history.append( + { + "epoch": float(epoch), + "train_loss": tr_loss, + "val_loss": va_loss, + "learning_rate": 
float(np.mean([opt.param_groups[0]["lr"] for opt in optimizers])), + } + ) + + if va_loss < best_val_loss - cfg.min_delta: + best_val_loss = va_loss + best_epoch = epoch + best_states = [deepcopy(model.state_dict()) for model in models] + wait = 0 + else: + wait += 1 + + if wait >= cfg.early_stopping_rounds: + break + + if best_states is not None: + for model, state in zip(models, best_states): + model.load_state_dict(state) + + val_prob, mean_ensemble_std = _predict_ensemble_probabilities(models, val_loader, device=device) + val_prob = np.clip(val_prob, 1e-7, 1.0 - 1e-7) + val_true = y_val.reshape(-1).astype(np.float32) + + mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0 + metrics = { + "auprc": float(average_precision_score(val_true, val_prob)), + "auroc": float(roc_auc_score(val_true, val_prob)), + "brier": float(brier_score_loss(val_true, val_prob)), + "nll": float(log_loss(val_true, val_prob)), + "ece": float(binary_ece(val_true, val_prob, n_bins=15)), + "mean_day_to_day_change": mean_change, + "normalized_consistency_score": normalized_consistency_score(mean_change), + "mean_ensemble_std": float(mean_ensemble_std), + } + + return models, history, metrics, best_epoch, pos_weight + + +def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + history_csv = output_dir / "history.csv" + with history_csv.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"]) + writer.writeheader() + writer.writerows(history) + + x = [int(r["epoch"]) for r in history] + y_tr = [float(r["train_loss"]) for r in history] + y_va = [float(r["val_loss"]) for r in history] + + plt.figure(figsize=(8, 5)) + plt.plot(x, y_tr, label="train_bce", marker="o", linewidth=1.4) + plt.plot(x, y_va, label="val_bce", marker="s", linewidth=1.2) + plt.xlabel("epoch") + plt.ylabel("loss") + 
def build_experiment_setting(
    cfg: DeepEnsembleTrackOConfig,
    best_epoch: int,
    pos_weight: float,
    metrics: Dict[str, float],
) -> Dict[str, Any]:
    """Assemble the JSON-serializable experiment card for a deep-ensemble run.

    Args:
        cfg: Training configuration; serialized in full under "params".
        best_epoch: Epoch whose weights were restored (early stopping).
        pos_weight: Clipped BCE positive-class weight that was applied.
        metrics: Validation metrics to embed under "val_metrics".
    """
    benchmark = {
        "task": "Track-O",
        "model_name": "deep_ensemble",
        "run_time": datetime.now().isoformat(),
    }
    protocol = {
        "discrimination": {"primary": "auprc", "secondary": "auroc"},
        "reliability": ["brier", "nll", "ece"],
        "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"],
        "uncertainty": ["mean_ensemble_std"],
    }
    training = {
        "train_unit": "epoch",
        "max_epochs": cfg.max_epochs,
        "early_stopping_rounds": cfg.early_stopping_rounds,
        "best_epoch": best_epoch,
        "seed": cfg.seed,
        "batch_size": cfg.batch_size,
        "ensemble_size": cfg.ensemble_size,
    }
    return {
        "benchmark": benchmark,
        "evaluation_protocol": protocol,
        "training": training,
        "optimizer": {
            "name": "AdamW",
            "lr": cfg.lr,
            "weight_decay": cfg.weight_decay,
        },
        "learning_weight": {
            "type": "pixel_pos_weight",
            "value": pos_weight,
            "clip_max": cfg.pos_weight_clip_max,
        },
        "params": asdict(cfg),
        "val_metrics": metrics,
        "note": "This module supports both real data and synthetic smoke demonstration.",
    }
_build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False) + test_prob, test_std = _predict_ensemble_probabilities(models, test_loader, _choose_device(cfg.device)) + test_prob = np.clip(test_prob, 1e-7, 1.0 - 1e-7) + test_true = y_test.reshape(-1).astype(np.float32) + + test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0 + test_metrics = { + "auprc": float(average_precision_score(test_true, test_prob)), + "auroc": float(roc_auc_score(test_true, test_prob)), + "brier": float(brier_score_loss(test_true, test_prob)), + "nll": float(log_loss(test_true, test_prob)), + "ece": float(binary_ece(test_true, test_prob, n_bins=15)), + "mean_day_to_day_change": test_mean_change, + "normalized_consistency_score": normalized_consistency_score(test_mean_change), + "mean_ensemble_std": float(test_std), + } + + output_dir.mkdir(parents=True, exist_ok=True) + save_history_and_plot(history, output_dir) + + torch.save( + { + "members": [model.state_dict() for model in models], + "config": asdict(cfg), + "best_epoch": best_epoch, + }, + output_dir / "deep_ensemble_model.pt", + ) + + setting = build_experiment_setting(cfg, best_epoch=best_epoch, pos_weight=pos_weight, metrics=val_metrics) + setting["data"] = { + "n_samples": n_samples, + "image_size": image_size, + "split": {"train": int(x_train.shape[0]), "val": int(x_val.shape[0]), "test": int(x_test.shape[0])}, + } + setting["test_metrics"] = test_metrics + + (output_dir / "experiment_setting.json").write_text(json.dumps(setting, indent=2), encoding="utf-8") + (output_dir / "metrics.json").write_text( + json.dumps({"val": val_metrics, "test": test_metrics, "best_epoch": best_epoch}, indent=2), + encoding="utf-8", + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run Deep Ensemble Track-O synthetic smoke demo") + parser.add_argument("--output_dir", default=None) + parser.add_argument("--seed", type=int, default=42) + 
parser.add_argument("--n_samples", type=int, default=192) + parser.add_argument("--image_size", type=int, default=24) + parser.add_argument("--max_epochs", type=int, default=60) + parser.add_argument("--early_stopping_rounds", type=int, default=12) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + base = Path(__file__).resolve().parents[1] / "runs_scaffold" + out = ( + Path(args.output_dir) + if args.output_dir + else base / f"deep_ensemble_synthetic_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + ) + + run_synthetic_demo( + output_dir=out, + seed=args.seed, + n_samples=args.n_samples, + image_size=args.image_size, + max_epochs=args.max_epochs, + early_stopping_rounds=args.early_stopping_rounds, + ) + print(f"[done] deep_ensemble synthetic demo saved to: {out}") + + +if __name__ == "__main__": + main() + +from ._wildfire_benchmark_utils import SegmentationPort, require_task + + +class DeepEnsemble(nn.Module): + """Benchmark-facing deep ensemble that averages member logits.""" + + def __init__(self, in_channels: int = 1, base_channels: int = 8, ensemble_size: int = 5): + super().__init__() + if ensemble_size < 1: + raise ValueError('ensemble_size must be >= 1') + self.members = nn.ModuleList( + [TinyEnsembleMember(in_channels=in_channels, base_channels=base_channels) for _ in range(int(ensemble_size))] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + logits = torch.stack([member(x) for member in self.members], dim=0).mean(dim=0) + return logits + + +def deep_ensemble_builder( + task: str, + in_channels: int = 1, + out_dim: int = 1, + base_channels: int = 8, + ensemble_size: int = 5, + **kwargs: Any, +) -> nn.Module: + _ = kwargs + require_task(task, {"segmentation"}, "deep_ensemble") + model = DeepEnsemble(in_channels=int(in_channels), base_channels=int(base_channels), ensemble_size=int(ensemble_size)) + return SegmentationPort(model=model, out_channels=int(out_dim)) + + +__all__ = ["DeepEnsemble", "TinyEnsembleMember", 
"deep_ensemble_builder"] diff --git a/pyhazards/models/deeplabv3p.py b/pyhazards/models/deeplabv3p.py new file mode 100644 index 00000000..0c04437a --- /dev/null +++ b/pyhazards/models/deeplabv3p.py @@ -0,0 +1,452 @@ +from __future__ import annotations + +import argparse +import csv +import json +import sys +from copy import deepcopy +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F +from sklearn.metrics import average_precision_score, brier_score_loss, log_loss, roc_auc_score +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + from pyhazards.models.unet import ( + binary_ece, + make_synthetic_fire_maps, + normalized_consistency_score, + split_train_val_test, + ) +else: + from .unet import ( + binary_ece, + make_synthetic_fire_maps, + normalized_consistency_score, + split_train_val_test, + ) + + +@dataclass +class DeepLabV3PTrackOConfig: + in_channels: int = 1 + base_channels: int = 16 + lr: float = 3e-4 + weight_decay: float = 1e-4 + batch_size: int = 8 + max_epochs: int = 120 + early_stopping_rounds: int = 16 + min_delta: float = 1e-4 + seed: int = 42 + pos_weight_clip_max: float = 50.0 + device: str = "cpu" + + +class ConvBNReLU(nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, dilation: int = 1): + super().__init__() + padding = dilation * (kernel_size // 2) + self.block = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x: torch.Tensor) -> 
torch.Tensor: + return self.block(x) + + +class TinyBackbone(nn.Module): + def __init__(self, in_channels: int = 1, base_channels: int = 16): + super().__init__() + c1, c2, c3 = base_channels, base_channels * 2, base_channels * 4 + self.low = nn.Sequential( + ConvBNReLU(in_channels, c1, kernel_size=3, stride=1), + ConvBNReLU(c1, c1, kernel_size=3, stride=1), + ) + self.mid = nn.Sequential( + ConvBNReLU(c1, c2, kernel_size=3, stride=2), + ConvBNReLU(c2, c2, kernel_size=3, stride=1), + ) + self.high = nn.Sequential( + ConvBNReLU(c2, c3, kernel_size=3, stride=2), + ConvBNReLU(c3, c3, kernel_size=3, stride=1), + ) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + low = self.low(x) + x = self.mid(low) + high = self.high(x) + return low, high + + +class ASPP(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.branch1 = ConvBNReLU(in_channels, out_channels, kernel_size=1, stride=1, dilation=1) + self.branch2 = ConvBNReLU(in_channels, out_channels, kernel_size=3, stride=1, dilation=2) + self.branch3 = ConvBNReLU(in_channels, out_channels, kernel_size=3, stride=1, dilation=4) + self.branch4 = ConvBNReLU(in_channels, out_channels, kernel_size=3, stride=1, dilation=6) + self.project = ConvBNReLU(out_channels * 4, out_channels, kernel_size=1, stride=1, dilation=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b1 = self.branch1(x) + b2 = self.branch2(x) + b3 = self.branch3(x) + b4 = self.branch4(x) + return self.project(torch.cat([b1, b2, b3, b4], dim=1)) + + +class TinyDeepLabV3P(nn.Module): + def __init__(self, in_channels: int = 1, base_channels: int = 16): + super().__init__() + c1, c3 = base_channels, base_channels * 4 + self.backbone = TinyBackbone(in_channels=in_channels, base_channels=base_channels) + self.aspp = ASPP(in_channels=c3, out_channels=c1 * 2) + self.low_proj = ConvBNReLU(c1, c1, kernel_size=1, stride=1, dilation=1) + self.decoder = nn.Sequential( + ConvBNReLU(c1 * 3, c1 * 
2, kernel_size=3, stride=1, dilation=1),
            ConvBNReLU(c1 * 2, c1, kernel_size=3, stride=1, dilation=1),
            nn.Conv2d(c1, 1, kernel_size=1),  # single-logit fire-probability map
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        low, high = self.backbone(x)
        aspp = self.aspp(high)
        # Upsample ASPP features back to the low-level (input) resolution.
        aspp_up = F.interpolate(aspp, size=low.shape[-2:], mode="bilinear", align_corners=False)
        low = self.low_proj(low)
        logits = self.decoder(torch.cat([aspp_up, low], dim=1))
        return logits


def _choose_device(device_text: str) -> torch.device:
    """Return cuda only when requested AND available; otherwise fall back to cpu."""
    if device_text == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader:
    """Wrap numpy arrays (cast to float32) in a TensorDataset DataLoader."""
    ds = TensorDataset(
        torch.from_numpy(x.astype(np.float32)),
        torch.from_numpy(y.astype(np.float32)),
    )
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)


def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray:
    """Run the model over ``loader`` and return flattened sigmoid probabilities."""
    probs: List[np.ndarray] = []
    model.eval()
    with torch.no_grad():
        for xb, _ in loader:
            xb = xb.to(device)
            logits = model(xb)
            p = torch.sigmoid(logits).detach().cpu().numpy()
            probs.append(p)
    if not probs:
        # Empty loader: keep a well-typed empty result.
        return np.zeros((0,), dtype=np.float32)
    return np.concatenate(probs, axis=0).reshape(-1)


def train_deeplabv3p_track_o(
    x_train: np.ndarray,
    y_train: np.ndarray,
    x_val: np.ndarray,
    y_val: np.ndarray,
    cfg: DeepLabV3PTrackOConfig,
):
    """Train TinyDeepLabV3P with early stopping on validation BCE.

    Returns ``(model, history, metrics, best_epoch, pos_weight)`` where
    ``model`` carries the best-epoch weights and ``metrics`` are computed on
    the validation split.
    Raises ValueError when inputs are not 4D [N,C,H,W] / [N,1,H,W].
    """
    if x_train.ndim != 4 or x_val.ndim != 4:
        raise ValueError("x_train and x_val must be 4D arrays [N,C,H,W]")
    if y_train.ndim != 4 or y_val.ndim != 4:
        raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]")

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    device = _choose_device(cfg.device)

    model = TinyDeepLabV3P(in_channels=cfg.in_channels, base_channels=cfg.base_channels).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    # Weight positives by the negative/positive pixel ratio, clipped so that a
    # vanishing positive class cannot blow the loss up.
    total_px = float(y_train.size)
    pos_px = float(np.sum(y_train))
    neg_px = max(1.0, total_px - pos_px)
    raw_pos_weight = neg_px / max(pos_px, 1.0)
    pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max))

    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

    train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True)
    val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False)

    history: List[Dict[str, float]] = []
    best_epoch = 1
    best_val_loss = float("inf")
    best_state: Dict[str, torch.Tensor] | None = None
    wait = 0  # epochs since last val-loss improvement

    for epoch in range(1, cfg.max_epochs + 1):
        model.train()
        train_losses: List[float] = []

        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)

            optimizer.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            optimizer.step()
            train_losses.append(float(loss.item()))

        model.eval()
        val_losses: List[float] = []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                logits = model(xb)
                loss = criterion(logits, yb)
                val_losses.append(float(loss.item()))

        tr_loss = float(np.mean(train_losses)) if train_losses else float("nan")
        va_loss = float(np.mean(val_losses)) if val_losses else float("nan")

        history.append(
            {
                "epoch": float(epoch),
                "train_loss": tr_loss,
                "val_loss": va_loss,
                "learning_rate": float(optimizer.param_groups[0]["lr"]),
            }
        )

        # Early stopping: improvement must exceed min_delta to reset patience.
        if va_loss < best_val_loss - cfg.min_delta:
            best_val_loss = va_loss
            best_epoch = epoch
            best_state = deepcopy(model.state_dict())
            wait = 0
        else:
            wait += 1

        if wait >= cfg.early_stopping_rounds:
            break

    # Restore the best-validation weights before computing metrics.
    if best_state is not None:
        model.load_state_dict(best_state)

    val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7)
    val_true = y_val.reshape(-1).astype(np.float32)

    # NOTE(review): computed on sorted probabilities — a smoothness proxy,
    # not a true day-to-day temporal difference.
    mean_change =
float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0
    metrics = {
        "auprc": float(average_precision_score(val_true, val_prob)),
        "auroc": float(roc_auc_score(val_true, val_prob)),
        "brier": float(brier_score_loss(val_true, val_prob)),
        "nll": float(log_loss(val_true, val_prob)),
        "ece": float(binary_ece(val_true, val_prob, n_bins=15)),
        "mean_day_to_day_change": mean_change,
        "normalized_consistency_score": normalized_consistency_score(mean_change),
    }

    return model, history, metrics, best_epoch, pos_weight


def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None:
    """Write the per-epoch history to history.csv and plot loss_curve.png."""
    output_dir.mkdir(parents=True, exist_ok=True)

    history_csv = output_dir / "history.csv"
    with history_csv.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"])
        writer.writeheader()
        writer.writerows(history)

    x = [int(r["epoch"]) for r in history]
    y_tr = [float(r["train_loss"]) for r in history]
    y_va = [float(r["val_loss"]) for r in history]

    plt.figure(figsize=(8, 5))
    plt.plot(x, y_tr, label="train_bce", marker="o", linewidth=1.4)
    plt.plot(x, y_va, label="val_bce", marker="s", linewidth=1.2)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("DeepLabv3+ Track-O: train loss vs epoch")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_dir / "loss_curve.png", dpi=150)
    plt.close()


def build_experiment_setting(
    cfg: DeepLabV3PTrackOConfig,
    best_epoch: int,
    pos_weight: float,
    metrics: Dict[str, float],
) -> Dict[str, Any]:
    """Assemble the experiment-setting record persisted alongside run artifacts."""
    return {
        "benchmark": {
            "task": "Track-O",
            "model_name": "deeplabv3p",
            "run_time": datetime.now().isoformat(),
        },
        "evaluation_protocol": {
            "discrimination": {"primary": "auprc", "secondary": "auroc"},
            "reliability": ["brier", "nll", "ece"],
            "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"],
        },
        "training": {
            "train_unit": "epoch",
            "max_epochs": cfg.max_epochs,
            "early_stopping_rounds": cfg.early_stopping_rounds,
            "best_epoch": best_epoch,
            "seed": cfg.seed,
            "batch_size": cfg.batch_size,
        },
        "optimizer": {
            "name": "AdamW",
            "lr": cfg.lr,
            "weight_decay": cfg.weight_decay,
        },
        "learning_weight": {
            "type": "pixel_pos_weight",
            "value": pos_weight,
            "clip_max": cfg.pos_weight_clip_max,
        },
        "params": asdict(cfg),
        "val_metrics": metrics,
        "note": "This module supports both real data and synthetic smoke demonstration.",
    }


def run_synthetic_demo(
    output_dir: Path,
    seed: int = 42,
    n_samples: int = 192,
    image_size: int = 24,
    max_epochs: int = 60,
    early_stopping_rounds: int = 12,
) -> None:
    """Train and evaluate TinyDeepLabV3P on synthetic fire maps.

    Writes history.csv, loss_curve.png, the model checkpoint,
    experiment_setting.json and metrics.json into ``output_dir``.
    """
    x, y = make_synthetic_fire_maps(n_samples=n_samples, image_size=image_size, seed=seed)
    x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed)

    cfg = DeepLabV3PTrackOConfig(
        seed=seed,
        max_epochs=max_epochs,
        early_stopping_rounds=early_stopping_rounds,
        device="cpu",
    )

    model, history, val_metrics, best_epoch, pos_weight = train_deeplabv3p_track_o(
        x_train,
        y_train,
        x_val,
        y_val,
        cfg,
    )

    test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False)
    test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 1e-7)
    test_true = y_test.reshape(-1).astype(np.float32)

    test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0
    test_metrics = {
        "auprc": float(average_precision_score(test_true, test_prob)),
        "auroc": float(roc_auc_score(test_true, test_prob)),
        "brier": float(brier_score_loss(test_true, test_prob)),
        "nll": float(log_loss(test_true, test_prob)),
        "ece": float(binary_ece(test_true, test_prob, n_bins=15)),
        "mean_day_to_day_change": test_mean_change,
        "normalized_consistency_score": normalized_consistency_score(test_mean_change),
    }

    output_dir.mkdir(parents=True, exist_ok=True)
    save_history_and_plot(history, output_dir)

    torch.save(
        {
            "state_dict": model.state_dict(),
            "config": asdict(cfg),
            "best_epoch": best_epoch,
        },
        output_dir / "deeplabv3p_model.pt",
    )

    setting = build_experiment_setting(cfg, best_epoch=best_epoch, pos_weight=pos_weight, metrics=val_metrics)
    setting["data"] = {
        "n_samples": n_samples,
        "image_size": image_size,
        "split": {"train": int(x_train.shape[0]), "val": int(x_val.shape[0]), "test": int(x_test.shape[0])},
    }
    setting["test_metrics"] = test_metrics

    (output_dir / "experiment_setting.json").write_text(json.dumps(setting, indent=2), encoding="utf-8")
    (output_dir / "metrics.json").write_text(
        json.dumps({"val": val_metrics, "test": test_metrics, "best_epoch": best_epoch}, indent=2),
        encoding="utf-8",
    )


def parse_args() -> argparse.Namespace:
    """Parse CLI flags for the synthetic smoke demo."""
    parser = argparse.ArgumentParser(description="Run DeepLabv3+ Track-O synthetic smoke demo")
    parser.add_argument("--output_dir", default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--n_samples", type=int, default=192)
    parser.add_argument("--image_size", type=int, default=24)
    parser.add_argument("--max_epochs", type=int, default=60)
    parser.add_argument("--early_stopping_rounds", type=int, default=12)
    return parser.parse_args()


def main() -> None:
    """CLI entry point: run the demo into a timestamped output directory."""
    args = parse_args()
    base = Path(__file__).resolve().parents[1] / "runs_scaffold"
    out = (
        Path(args.output_dir)
        if args.output_dir
        else base / f"deeplabv3p_synthetic_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    )

    run_synthetic_demo(
        output_dir=out,
        seed=args.seed,
        n_samples=args.n_samples,
        image_size=args.image_size,
        max_epochs=args.max_epochs,
        early_stopping_rounds=args.early_stopping_rounds,
    )
    print(f"[done] deeplabv3p synthetic demo saved to: {out}")


if __name__ == "__main__":
    main()

from ._wildfire_benchmark_utils import SegmentationPort,
filter_init_kwargs, require_task


def deeplabv3p_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module:
    """Registry builder: wrap TinyDeepLabV3P in the benchmark SegmentationPort.

    Unknown kwargs are filtered against TinyDeepLabV3P's __init__ signature.
    Raises via require_task if ``task`` is not "segmentation".
    """
    require_task(task, {"segmentation"}, "deeplabv3p")
    init_kwargs = filter_init_kwargs(TinyDeepLabV3P, {"in_channels": int(in_channels), **kwargs})
    model = TinyDeepLabV3P(**init_kwargs)
    return SegmentationPort(model=model, out_channels=int(out_dim))


__all__ = ["TinyDeepLabV3P", "deeplabv3p_builder"]
diff --git a/pyhazards/models/earthfarseer.py b/pyhazards/models/earthfarseer.py
new file mode 100644
index 00000000..65b14fc3
--- /dev/null
+++ b/pyhazards/models/earthfarseer.py
@@ -0,0 +1,453 @@
from __future__ import annotations

import argparse
import csv
import json
import sys
from copy import deepcopy
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import average_precision_score, brier_score_loss, log_loss, roc_auc_score
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Allow running both as a package module and as a standalone script.
if __package__ is None or __package__ == "":
    sys.path.insert(0, str(Path(__file__).resolve().parents[3]))
    from pyhazards.models.mau import (
        binary_ece,
        make_synthetic_fire_sequences,
        normalized_consistency_score,
        split_train_val_test,
    )
else:
    from .mau import (
        binary_ece,
        make_synthetic_fire_sequences,
        normalized_consistency_score,
        split_train_val_test,
    )


@dataclass
class EarthFarseerTrackOConfig:
    """Hyperparameters for the Track-O EarthFarseer trainer."""

    seq_len: int = 6
    in_channels: int = 1
    hidden_channels: int = 16
    num_heads: int = 4
    num_layers: int = 2
    lr: float = 3e-4
    weight_decay: float = 1e-4
    batch_size: int = 8
    max_epochs: int = 120
    early_stopping_rounds: int = 16
    min_delta: float = 1e-4
    seed: int = 42
    pos_weight_clip_max: float = 50.0  # ceiling for the BCE positive-pixel weight
    device: str = "cpu"


class TemporalBlock(nn.Module):
    """Pre-norm self-attention + feed-forward residual block over the time axis."""

    def __init__(self, dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 2, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.norm1(x)
        y, _ = self.attn(y, y, y, need_weights=False)
        x = x + y
        x = x + self.ffn(self.norm2(x))
        return x


class TinyEarthFarseer(nn.Module):
    """Compact EarthFarseer-style predictor: per-frame conv encoder, per-pixel
    temporal attention (main + subsampled "far" branch), conv decoder."""

    def __init__(self, in_channels: int = 1, hidden_channels: int = 16, num_heads: int = 4, num_layers: int = 2):
        super().__init__()
        self.hidden_channels = hidden_channels

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        self.temporal_blocks = nn.ModuleList([TemporalBlock(hidden_channels, num_heads) for _ in range(num_layers)])
        # The far branch is one layer shallower (but at least one block).
        self.far_blocks = nn.ModuleList([TemporalBlock(hidden_channels, num_heads) for _ in range(max(1, num_layers - 1))])

        self.decoder = nn.Sequential(
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, 1, kernel_size=1),
        )

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        # x_seq: [B,T,C,H,W]
        b, t, c, h, w = x_seq.shape

        x = x_seq.reshape(b * t, c, h, w)
        x = self.encoder(x)
        d = x.shape[1]

        # Treat each pixel as an independent temporal token sequence.
        x = x.reshape(b, t, d, h, w)
        x = x.permute(0, 3, 4, 1, 2).contiguous().reshape(b * h * w, t, d)  # [BHW,T,D]

        x_main = x
        for blk in self.temporal_blocks:
            x_main = blk(x_main)

        # Far-seer branch emphasizes farther temporal gaps.
        x_far = x[:, ::2, :]  # every other timestep
        for blk in self.far_blocks:
            x_far = blk(x_far)

        # Fuse last-step representations of both branches per pixel.
        x_last = x_main[:, -1, :] + x_far[:, -1, :]
        x_last = x_last.reshape(b, h, w, d).permute(0, 3, 1, 2).contiguous()
        return self.decoder(x_last)


def _choose_device(device_text: str) -> torch.device:
    """Return cuda only when requested AND available; otherwise fall back to cpu."""
    if device_text == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader:
    """Wrap numpy arrays (cast to float32) in a TensorDataset DataLoader."""
    ds = TensorDataset(
        torch.from_numpy(x.astype(np.float32)),
        torch.from_numpy(y.astype(np.float32)),
    )
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)


def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray:
    """Run the model over ``loader`` and return flattened sigmoid probabilities."""
    probs: List[np.ndarray] = []
    model.eval()
    with torch.no_grad():
        for xb, _ in loader:
            xb = xb.to(device)
            logits = model(xb)
            p = torch.sigmoid(logits).detach().cpu().numpy()
            probs.append(p)
    if not probs:
        return np.zeros((0,), dtype=np.float32)
    return np.concatenate(probs, axis=0).reshape(-1)


def train_earthfarseer_track_o(
    x_train: np.ndarray,
    y_train: np.ndarray,
    x_val: np.ndarray,
    y_val: np.ndarray,
    cfg: EarthFarseerTrackOConfig,
):
    """Train TinyEarthFarseer with early stopping on validation BCE.

    Returns ``(model, history, metrics, best_epoch, pos_weight)`` with the
    best-epoch weights restored and metrics computed on the validation split.
    Raises ValueError when inputs are not 5D [N,T,C,H,W] / 4D [N,1,H,W].
    """
    if x_train.ndim != 5 or x_val.ndim != 5:
        raise ValueError("x_train and x_val must be 5D arrays [N,T,C,H,W]")
    if y_train.ndim != 4 or y_val.ndim != 4:
        raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]")

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    device = _choose_device(cfg.device)

    model = TinyEarthFarseer(
        in_channels=cfg.in_channels,
        hidden_channels=cfg.hidden_channels,
        num_heads=cfg.num_heads,
        num_layers=cfg.num_layers,
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    # Positive-pixel weight = clipped negative/positive pixel ratio.
    total_px = float(y_train.size)
    pos_px = float(np.sum(y_train))
    neg_px = max(1.0, total_px - pos_px)
    raw_pos_weight = neg_px / max(pos_px, 1.0)
    pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max))

    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

    train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True)
    val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False)

    history: List[Dict[str, float]] = []
    best_epoch = 1
    best_val_loss = float("inf")
    best_state: Dict[str, torch.Tensor] | None = None
    wait = 0  # epochs since last val-loss improvement

    for epoch in range(1, cfg.max_epochs + 1):
        model.train()
        train_losses: List[float] = []

        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)

            optimizer.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            optimizer.step()
            train_losses.append(float(loss.item()))

        model.eval()
        val_losses: List[float] = []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                logits = model(xb)
                loss = criterion(logits, yb)
                val_losses.append(float(loss.item()))

        tr_loss = float(np.mean(train_losses)) if train_losses else float("nan")
        va_loss = float(np.mean(val_losses)) if val_losses else float("nan")

        history.append(
            {
                "epoch": float(epoch),
                "train_loss": tr_loss,
                "val_loss": va_loss,
                "learning_rate": float(optimizer.param_groups[0]["lr"]),
            }
        )

        # Early stopping: improvement must exceed min_delta to reset patience.
        if va_loss < best_val_loss - cfg.min_delta:
            best_val_loss = va_loss
            best_epoch = epoch
            best_state = deepcopy(model.state_dict())
            wait = 0
        else:
            wait += 1

        if wait >= cfg.early_stopping_rounds:
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7)
    val_true = y_val.reshape(-1).astype(np.float32)

    # NOTE(review): computed on sorted probabilities — a smoothness proxy.
    mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0
    metrics = {
        "auprc": float(average_precision_score(val_true, val_prob)),
"auroc": float(roc_auc_score(val_true, val_prob)), + "brier": float(brier_score_loss(val_true, val_prob)), + "nll": float(log_loss(val_true, val_prob)), + "ece": float(binary_ece(val_true, val_prob, n_bins=15)), + "mean_day_to_day_change": mean_change, + "normalized_consistency_score": normalized_consistency_score(mean_change), + } + + return model, history, metrics, best_epoch, pos_weight + + +def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + history_csv = output_dir / "history.csv" + with history_csv.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"]) + writer.writeheader() + writer.writerows(history) + + x = [int(r["epoch"]) for r in history] + y_tr = [float(r["train_loss"]) for r in history] + y_va = [float(r["val_loss"]) for r in history] + + plt.figure(figsize=(8, 5)) + plt.plot(x, y_tr, label="train_bce", marker="o", linewidth=1.4) + plt.plot(x, y_va, label="val_bce", marker="s", linewidth=1.2) + plt.xlabel("epoch") + plt.ylabel("loss") + plt.title("EarthFarseer Track-O: train loss vs epoch") + plt.grid(alpha=0.3) + plt.legend() + plt.tight_layout() + plt.savefig(output_dir / "loss_curve.png", dpi=150) + plt.close() + + +def build_experiment_setting( + cfg: EarthFarseerTrackOConfig, + best_epoch: int, + pos_weight: float, + metrics: Dict[str, float], +) -> Dict[str, Any]: + return { + "benchmark": { + "task": "Track-O", + "model_name": "earthfarseer", + "run_time": datetime.now().isoformat(), + }, + "evaluation_protocol": { + "discrimination": {"primary": "auprc", "secondary": "auroc"}, + "reliability": ["brier", "nll", "ece"], + "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"], + }, + "training": { + "train_unit": "epoch", + "max_epochs": cfg.max_epochs, + "early_stopping_rounds": cfg.early_stopping_rounds, + "best_epoch": best_epoch, + "seed": 
cfg.seed, + "batch_size": cfg.batch_size, + "seq_len": cfg.seq_len, + }, + "optimizer": { + "name": "AdamW", + "lr": cfg.lr, + "weight_decay": cfg.weight_decay, + }, + "learning_weight": { + "type": "pixel_pos_weight", + "value": pos_weight, + "clip_max": cfg.pos_weight_clip_max, + }, + "params": asdict(cfg), + "val_metrics": metrics, + "note": "This module supports both real data and synthetic smoke demonstration.", + } + + +def run_synthetic_demo( + output_dir: Path, + seed: int = 42, + n_samples: int = 160, + seq_len: int = 6, + image_size: int = 24, + max_epochs: int = 60, + early_stopping_rounds: int = 12, +) -> None: + x, y = make_synthetic_fire_sequences( + n_samples=n_samples, + seq_len=seq_len, + image_size=image_size, + seed=seed, + ) + x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed) + + cfg = EarthFarseerTrackOConfig( + seed=seed, + seq_len=seq_len, + max_epochs=max_epochs, + early_stopping_rounds=early_stopping_rounds, + device="cpu", + ) + + model, history, val_metrics, best_epoch, pos_weight = train_earthfarseer_track_o( + x_train, + y_train, + x_val, + y_val, + cfg, + ) + + test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False) + test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 1e-7) + test_true = y_test.reshape(-1).astype(np.float32) + + test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0 + test_metrics = { + "auprc": float(average_precision_score(test_true, test_prob)), + "auroc": float(roc_auc_score(test_true, test_prob)), + "brier": float(brier_score_loss(test_true, test_prob)), + "nll": float(log_loss(test_true, test_prob)), + "ece": float(binary_ece(test_true, test_prob, n_bins=15)), + "mean_day_to_day_change": test_mean_change, + "normalized_consistency_score": normalized_consistency_score(test_mean_change), + } + + output_dir.mkdir(parents=True, exist_ok=True) + 
save_history_and_plot(history, output_dir) + + torch.save( + { + "state_dict": model.state_dict(), + "config": asdict(cfg), + "best_epoch": best_epoch, + }, + output_dir / "earthfarseer_model.pt", + ) + + setting = build_experiment_setting(cfg, best_epoch=best_epoch, pos_weight=pos_weight, metrics=val_metrics) + setting["data"] = { + "n_samples": n_samples, + "image_size": image_size, + "seq_len": seq_len, + "split": {"train": int(x_train.shape[0]), "val": int(x_val.shape[0]), "test": int(x_test.shape[0])}, + } + setting["test_metrics"] = test_metrics + + (output_dir / "experiment_setting.json").write_text(json.dumps(setting, indent=2), encoding="utf-8") + (output_dir / "metrics.json").write_text( + json.dumps({"val": val_metrics, "test": test_metrics, "best_epoch": best_epoch}, indent=2), + encoding="utf-8", + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run EarthFarseer Track-O synthetic smoke demo") + parser.add_argument("--output_dir", default=None) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--n_samples", type=int, default=160) + parser.add_argument("--seq_len", type=int, default=6) + parser.add_argument("--image_size", type=int, default=24) + parser.add_argument("--max_epochs", type=int, default=60) + parser.add_argument("--early_stopping_rounds", type=int, default=12) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + base = Path(__file__).resolve().parents[1] / "runs_scaffold" + out = ( + Path(args.output_dir) + if args.output_dir + else base / f"earthfarseer_synthetic_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + ) + + run_synthetic_demo( + output_dir=out, + seed=args.seed, + n_samples=args.n_samples, + seq_len=args.seq_len, + image_size=args.image_size, + max_epochs=args.max_epochs, + early_stopping_rounds=args.early_stopping_rounds, + ) + print(f"[done] earthfarseer synthetic demo saved to: {out}") + + +if __name__ == "__main__": + main() + +from 
._wildfire_benchmark_utils import SegmentationPort, filter_init_kwargs, require_task


def earthfarseer_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module:
    """Registry builder: wrap TinyEarthFarseer in the benchmark SegmentationPort.

    Unknown kwargs are filtered against TinyEarthFarseer's __init__ signature.
    Raises via require_task if ``task`` is not "segmentation".
    """
    require_task(task, {"segmentation"}, "earthfarseer")
    init_kwargs = filter_init_kwargs(TinyEarthFarseer, {"in_channels": int(in_channels), **kwargs})
    model = TinyEarthFarseer(**init_kwargs)
    return SegmentationPort(model=model, out_channels=int(out_dim))


__all__ = ["TinyEarthFarseer", "earthfarseer_builder"]
diff --git a/pyhazards/models/earthformer.py b/pyhazards/models/earthformer.py
new file mode 100644
index 00000000..745b6e85
--- /dev/null
+++ b/pyhazards/models/earthformer.py
@@ -0,0 +1,448 @@
from __future__ import annotations

import argparse
import csv
import json
import sys
from copy import deepcopy
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import average_precision_score, brier_score_loss, log_loss, roc_auc_score
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Allow running both as a package module and as a standalone script.
if __package__ is None or __package__ == "":
    sys.path.insert(0, str(Path(__file__).resolve().parents[3]))
    from pyhazards.models.mau import (
        binary_ece,
        make_synthetic_fire_sequences,
        normalized_consistency_score,
        split_train_val_test,
    )
else:
    from .mau import (
        binary_ece,
        make_synthetic_fire_sequences,
        normalized_consistency_score,
        split_train_val_test,
    )


@dataclass
class EarthFormerTrackOConfig:
    """Hyperparameters for the Track-O EarthFormer trainer."""

    seq_len: int = 6
    in_channels: int = 1
    hidden_channels: int = 16
    num_heads: int = 4
    num_layers: int = 2
    lr: float = 1e-3
    weight_decay: float = 1e-4
    batch_size: int = 8
    max_epochs: int = 120
    early_stopping_rounds: int = 16
    min_delta: float = 1e-4
    seed: int = 42
    pos_weight_clip_max: float = 50.0  # ceiling for the BCE positive-pixel weight
    device: str = "cpu"


class TemporalTransformerBlock(nn.Module):
    """Pre-norm self-attention + MLP residual block over the time axis."""

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 2.0, dropout: float = 0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [BHW, T, D]
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x


class TinyEarthFormer(nn.Module):
    """Compact EarthFormer-style predictor: per-frame conv encoder, per-pixel
    temporal transformer stack, conv decoder on the last timestep."""

    def __init__(self, in_channels: int = 1, hidden_channels: int = 16, num_heads: int = 4, num_layers: int = 2):
        super().__init__()
        self.hidden_channels = hidden_channels

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        self.temporal_blocks = nn.ModuleList(
            [TemporalTransformerBlock(dim=hidden_channels, num_heads=num_heads) for _ in range(num_layers)]
        )

        self.decoder = nn.Sequential(
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, 1, kernel_size=1),
        )

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        # x_seq: [B, T, C, H, W]
        b, t, c, h, w = x_seq.shape

        x = x_seq.reshape(b * t, c, h, w)
        x = self.encoder(x)
        d = x.shape[1]

        # Treat each pixel as an independent temporal token sequence.
        x = x.reshape(b, t, d, h, w)
        x = x.permute(0, 3, 4, 1, 2).contiguous()  # [B,H,W,T,D]
        x = x.reshape(b * h * w, t, d)

        for block in self.temporal_blocks:
            x = block(x)

        x = x[:, -1, :]  # next-step representation per location
        x = x.reshape(b, h, w, d).permute(0, 3, 1, 2).contiguous()  # [B,D,H,W]

        return self.decoder(x)


def _choose_device(device_text: str) -> torch.device:
    """Return cuda only when requested AND available; otherwise fall back to cpu."""
    if device_text == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader:
    """Wrap numpy arrays (cast to float32) in a TensorDataset DataLoader."""
    ds = TensorDataset(
        torch.from_numpy(x.astype(np.float32)),
        torch.from_numpy(y.astype(np.float32)),
    )
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)


def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray:
    """Run the model over ``loader`` and return flattened sigmoid probabilities."""
    probs: List[np.ndarray] = []
    model.eval()
    with torch.no_grad():
        for xb, _ in loader:
            xb = xb.to(device)
            logits = model(xb)
            p = torch.sigmoid(logits).detach().cpu().numpy()
            probs.append(p)
    if not probs:
        return np.zeros((0,), dtype=np.float32)
    return np.concatenate(probs, axis=0).reshape(-1)


def train_earthformer_track_o(
    x_train: np.ndarray,
    y_train: np.ndarray,
    x_val: np.ndarray,
    y_val: np.ndarray,
    cfg: EarthFormerTrackOConfig,
):
    # NOTE: this function continues beyond the visible chunk.
    if x_train.ndim != 5 or x_val.ndim != 5:
        raise ValueError("x_train and x_val must be 5D arrays [N,T,C,H,W]")
    if y_train.ndim != 4 or y_val.ndim != 4:
        raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]")

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    device = _choose_device(cfg.device)

    model = TinyEarthFormer(
        in_channels=cfg.in_channels,
        hidden_channels=cfg.hidden_channels,
        num_heads=cfg.num_heads,
        num_layers=cfg.num_layers,
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    # Positive-pixel weight = clipped negative/positive pixel ratio.
    total_px = float(y_train.size)
    pos_px = float(np.sum(y_train))
    neg_px = max(1.0, total_px - pos_px)
    raw_pos_weight = neg_px / max(pos_px, 1.0)
    pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max))

    criterion =
nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device)) + + train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True) + val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False) + + history: List[Dict[str, float]] = [] + best_epoch = 1 + best_val_loss = float("inf") + best_state: Dict[str, torch.Tensor] | None = None + wait = 0 + + for epoch in range(1, cfg.max_epochs + 1): + model.train() + train_losses: List[float] = [] + + for xb, yb in train_loader: + xb = xb.to(device) + yb = yb.to(device) + + optimizer.zero_grad(set_to_none=True) + logits = model(xb) + loss = criterion(logits, yb) + loss.backward() + optimizer.step() + train_losses.append(float(loss.item())) + + model.eval() + val_losses: List[float] = [] + with torch.no_grad(): + for xb, yb in val_loader: + xb = xb.to(device) + yb = yb.to(device) + logits = model(xb) + loss = criterion(logits, yb) + val_losses.append(float(loss.item())) + + tr_loss = float(np.mean(train_losses)) if train_losses else float("nan") + va_loss = float(np.mean(val_losses)) if val_losses else float("nan") + + history.append( + { + "epoch": float(epoch), + "train_loss": tr_loss, + "val_loss": va_loss, + "learning_rate": float(optimizer.param_groups[0]["lr"]), + } + ) + + if va_loss < best_val_loss - cfg.min_delta: + best_val_loss = va_loss + best_epoch = epoch + best_state = deepcopy(model.state_dict()) + wait = 0 + else: + wait += 1 + + if wait >= cfg.early_stopping_rounds: + break + + if best_state is not None: + model.load_state_dict(best_state) + + val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7) + val_true = y_val.reshape(-1).astype(np.float32) + + mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0 + metrics = { + "auprc": float(average_precision_score(val_true, val_prob)), + "auroc": float(roc_auc_score(val_true, val_prob)), + "brier": 
float(brier_score_loss(val_true, val_prob)), + "nll": float(log_loss(val_true, val_prob)), + "ece": float(binary_ece(val_true, val_prob, n_bins=15)), + "mean_day_to_day_change": mean_change, + "normalized_consistency_score": normalized_consistency_score(mean_change), + } + + return model, history, metrics, best_epoch, pos_weight + + +def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + history_csv = output_dir / "history.csv" + with history_csv.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"]) + writer.writeheader() + writer.writerows(history) + + x = [int(r["epoch"]) for r in history] + y_tr = [float(r["train_loss"]) for r in history] + y_va = [float(r["val_loss"]) for r in history] + + plt.figure(figsize=(8, 5)) + plt.plot(x, y_tr, label="train_bce", marker="o", linewidth=1.4) + plt.plot(x, y_va, label="val_bce", marker="s", linewidth=1.2) + plt.xlabel("epoch") + plt.ylabel("loss") + plt.title("Earthformer Track-O: train loss vs epoch") + plt.grid(alpha=0.3) + plt.legend() + plt.tight_layout() + plt.savefig(output_dir / "loss_curve.png", dpi=150) + plt.close() + + +def build_experiment_setting( + cfg: EarthFormerTrackOConfig, + best_epoch: int, + pos_weight: float, + metrics: Dict[str, float], +) -> Dict[str, Any]: + return { + "benchmark": { + "task": "Track-O", + "model_name": "earthformer", + "run_time": datetime.now().isoformat(), + }, + "evaluation_protocol": { + "discrimination": {"primary": "auprc", "secondary": "auroc"}, + "reliability": ["brier", "nll", "ece"], + "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"], + }, + "training": { + "train_unit": "epoch", + "max_epochs": cfg.max_epochs, + "early_stopping_rounds": cfg.early_stopping_rounds, + "best_epoch": best_epoch, + "seed": cfg.seed, + "batch_size": cfg.batch_size, + "seq_len": cfg.seq_len, + 
}, + "optimizer": { + "name": "AdamW", + "lr": cfg.lr, + "weight_decay": cfg.weight_decay, + }, + "learning_weight": { + "type": "pixel_pos_weight", + "value": pos_weight, + "clip_max": cfg.pos_weight_clip_max, + }, + "params": asdict(cfg), + "val_metrics": metrics, + "note": "This module supports both real data and synthetic smoke demonstration.", + } + + +def run_synthetic_demo( + output_dir: Path, + seed: int = 42, + n_samples: int = 160, + seq_len: int = 6, + image_size: int = 24, + max_epochs: int = 60, + early_stopping_rounds: int = 12, +) -> None: + x, y = make_synthetic_fire_sequences( + n_samples=n_samples, + seq_len=seq_len, + image_size=image_size, + seed=seed, + ) + x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed) + + cfg = EarthFormerTrackOConfig( + seed=seed, + seq_len=seq_len, + max_epochs=max_epochs, + early_stopping_rounds=early_stopping_rounds, + device="cpu", + ) + + model, history, val_metrics, best_epoch, pos_weight = train_earthformer_track_o( + x_train, + y_train, + x_val, + y_val, + cfg, + ) + + test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False) + test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 1e-7) + test_true = y_test.reshape(-1).astype(np.float32) + + test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0 + test_metrics = { + "auprc": float(average_precision_score(test_true, test_prob)), + "auroc": float(roc_auc_score(test_true, test_prob)), + "brier": float(brier_score_loss(test_true, test_prob)), + "nll": float(log_loss(test_true, test_prob)), + "ece": float(binary_ece(test_true, test_prob, n_bins=15)), + "mean_day_to_day_change": test_mean_change, + "normalized_consistency_score": normalized_consistency_score(test_mean_change), + } + + output_dir.mkdir(parents=True, exist_ok=True) + save_history_and_plot(history, output_dir) + + torch.save( + { + "state_dict": 
model.state_dict(), + "config": asdict(cfg), + "best_epoch": best_epoch, + }, + output_dir / "earthformer_model.pt", + ) + + setting = build_experiment_setting(cfg, best_epoch=best_epoch, pos_weight=pos_weight, metrics=val_metrics) + setting["data"] = { + "n_samples": n_samples, + "image_size": image_size, + "seq_len": seq_len, + "split": {"train": int(x_train.shape[0]), "val": int(x_val.shape[0]), "test": int(x_test.shape[0])}, + } + setting["test_metrics"] = test_metrics + + (output_dir / "experiment_setting.json").write_text(json.dumps(setting, indent=2), encoding="utf-8") + (output_dir / "metrics.json").write_text( + json.dumps({"val": val_metrics, "test": test_metrics, "best_epoch": best_epoch}, indent=2), + encoding="utf-8", + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run Earthformer Track-O synthetic smoke demo") + parser.add_argument("--output_dir", default=None) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--n_samples", type=int, default=160) + parser.add_argument("--seq_len", type=int, default=6) + parser.add_argument("--image_size", type=int, default=24) + parser.add_argument("--max_epochs", type=int, default=60) + parser.add_argument("--early_stopping_rounds", type=int, default=12) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + base = Path(__file__).resolve().parents[1] / "runs_scaffold" + out = Path(args.output_dir) if args.output_dir else base / f"earthformer_synthetic_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + run_synthetic_demo( + output_dir=out, + seed=args.seed, + n_samples=args.n_samples, + seq_len=args.seq_len, + image_size=args.image_size, + max_epochs=args.max_epochs, + early_stopping_rounds=args.early_stopping_rounds, + ) + print(f"[done] earthformer synthetic demo saved to: {out}") + + +if __name__ == "__main__": + main() + +from ._wildfire_benchmark_utils import SegmentationPort, filter_init_kwargs, require_task + + +def 
earthformer_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module: + require_task(task, {"segmentation"}, "earthformer") + init_kwargs = filter_init_kwargs(TinyEarthFormer, {"in_channels": int(in_channels), **kwargs}) + model = TinyEarthFormer(**init_kwargs) + return SegmentationPort(model=model, out_channels=int(out_dim)) + + +__all__ = ["TinyEarthFormer", "earthformer_builder"] diff --git a/pyhazards/models/firecastnet.py b/pyhazards/models/firecastnet.py index 0812f9dd..74c27e2e 100644 --- a/pyhazards/models/firecastnet.py +++ b/pyhazards/models/firecastnet.py @@ -5,7 +5,7 @@ class FireCastNet(nn.Module): - """Compact encoder-decoder wildfire spread network.""" + """Compact encoder-decoder wildfire forecasting network.""" def __init__( self, diff --git a/pyhazards/models/firemm_ir.py b/pyhazards/models/firemm_ir.py new file mode 100644 index 00000000..3754a748 --- /dev/null +++ b/pyhazards/models/firemm_ir.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +from typing import Any, Dict, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ModalityEncoder(nn.Module): + def __init__(self, in_channels: int, hidden_dim: int): + super().__init__() + self.net = nn.Sequential( + nn.Conv2d(in_channels, hidden_dim // 2, kernel_size=3, padding=1), + nn.GELU(), + nn.Conv2d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1), + nn.GELU(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class ClassAwareMemory(nn.Module): + """Small memory bank inspired by FireMM-IR's class-aware memory module.""" + + def __init__(self, hidden_dim: int, num_memory_slots: int = 3): + super().__init__() + self.hidden_dim = int(hidden_dim) + self.num_memory_slots = int(num_memory_slots) + self.memory = nn.Parameter(torch.randn(self.num_memory_slots, self.hidden_dim) * 0.02) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch, channels, height, width = x.shape + tokens 
= x.flatten(2).transpose(1, 2) # (B, HW, C) + scores = torch.matmul(tokens, self.memory.t()) / max(1.0, self.hidden_dim ** 0.5) + weights = torch.softmax(scores, dim=-1) + retrieved = torch.matmul(weights, self.memory).transpose(1, 2).reshape(batch, channels, height, width) + return x + retrieved + + +class FireMMIR(nn.Module): + """Dual-modality wildfire scene model inspired by FireMM-IR.""" + + def __init__( + self, + in_channels: int = 6, + out_dim: int = 1, + hidden_dim: int = 64, + instruction_dim: int = 16, + num_memory_slots: int = 3, + num_heads: int = 4, + dropout: float = 0.1, + ): + super().__init__() + if in_channels < 2 or in_channels % 2 != 0: + raise ValueError(f"in_channels must be an even number >= 2, got {in_channels}") + if out_dim <= 0: + raise ValueError(f"out_dim must be positive, got {out_dim}") + if hidden_dim % num_heads != 0: + raise ValueError(f"hidden_dim={hidden_dim} must be divisible by num_heads={num_heads}") + + self.in_channels = int(in_channels) + self.hidden_dim = int(hidden_dim) + self.instruction_dim = int(instruction_dim) + self.optical_channels = self.in_channels // 2 + self.infrared_channels = self.in_channels - self.optical_channels + + self.optical_encoder = ModalityEncoder(self.optical_channels, hidden_dim) + self.infrared_encoder = ModalityEncoder(self.infrared_channels, hidden_dim) + self.fusion_gate = nn.Sequential( + nn.Conv2d(hidden_dim * 2, hidden_dim, kernel_size=1), + nn.GELU(), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1), + nn.Sigmoid(), + ) + self.memory = ClassAwareMemory(hidden_dim=hidden_dim, num_memory_slots=num_memory_slots) + + self.instruction_proj = nn.Linear(self.instruction_dim, hidden_dim) + self.segmentation_token = nn.Parameter(torch.randn(1, hidden_dim) * 0.02) + self.token_attn = nn.MultiheadAttention( + embed_dim=hidden_dim, + num_heads=num_heads, + dropout=dropout, + batch_first=True, + ) + self.ffn = nn.Sequential( + nn.LayerNorm(hidden_dim), + nn.Linear(hidden_dim, hidden_dim * 2), + 
nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim * 2, hidden_dim), + ) + self.decoder = nn.Sequential( + nn.Conv2d(hidden_dim * 2, hidden_dim, kernel_size=3, padding=1), + nn.GELU(), + nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=3, padding=1), + nn.GELU(), + ) + self.head = nn.Conv2d(hidden_dim // 2, out_dim, kernel_size=1) + + def _unpack_inputs(self, inputs: torch.Tensor | Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor | None]: + if isinstance(inputs, dict): + x = inputs.get("x") + instruction = inputs.get("instruction_context") + else: + x = inputs + instruction = None + + if not isinstance(x, torch.Tensor): + raise ValueError("FireMMIR expects a tensor input or a dict containing key 'x'.") + if x.ndim != 4: + raise ValueError(f"FireMMIR expects input shape (B, C, H, W), got {tuple(x.shape)}") + if x.size(1) != self.in_channels: + raise ValueError(f"FireMMIR expected in_channels={self.in_channels}, got {x.size(1)}") + return x, instruction + + def _coerce_instruction(self, instruction: torch.Tensor | None, batch: int, device: torch.device) -> torch.Tensor: + if instruction is None: + return torch.zeros(batch, self.instruction_dim, device=device) + if instruction.ndim != 2 or instruction.size(0) != batch: + raise ValueError(f"instruction_context must have shape (B,D), got {tuple(instruction.shape)}") + if instruction.size(1) == self.instruction_dim: + return instruction.to(device=device, dtype=torch.float32) + if instruction.size(1) > self.instruction_dim: + return instruction[:, : self.instruction_dim].to(device=device, dtype=torch.float32) + pad = torch.zeros(batch, self.instruction_dim - instruction.size(1), device=device) + return torch.cat([instruction.to(device=device, dtype=torch.float32), pad], dim=1) + + def forward(self, inputs: torch.Tensor | Dict[str, Any]) -> torch.Tensor: + x, instruction = self._unpack_inputs(inputs) + batch = x.size(0) + device = x.device + + optical = x[:, : self.optical_channels] + infrared = x[:, 
self.optical_channels :] + optical_feat = self.optical_encoder(optical) + infrared_feat = self.infrared_encoder(infrared) + gate = self.fusion_gate(torch.cat([optical_feat, infrared_feat], dim=1)) + fused = optical_feat + gate * infrared_feat + fused = self.memory(fused) + + visual_tokens = fused.flatten(2).transpose(1, 2) + instruction_token = self.instruction_proj(self._coerce_instruction(instruction, batch, device)).unsqueeze(1) + seg_token = self.segmentation_token.unsqueeze(0).expand(batch, -1, -1) + query_tokens = torch.cat([seg_token, instruction_token], dim=1) + attn_out, _ = self.token_attn(query_tokens, visual_tokens, visual_tokens, need_weights=False) + query_tokens = attn_out + self.ffn(attn_out) + global_token = query_tokens.mean(dim=1) + + context_map = global_token.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, fused.size(-2), fused.size(-1)) + decoded = self.decoder(torch.cat([fused, context_map], dim=1)) + return self.head(decoded) + + +def firemm_ir_builder( + task: str, + in_channels: int = 6, + out_dim: int = 1, + hidden_dim: int = 64, + instruction_dim: int = 16, + num_memory_slots: int = 3, + num_heads: int = 4, + dropout: float = 0.1, + **kwargs, +) -> nn.Module: + _ = kwargs + if task.lower() != "segmentation": + raise ValueError(f"firemm_ir is segmentation-only in PyHazards, got task={task!r}.") + return FireMMIR( + in_channels=in_channels, + out_dim=out_dim, + hidden_dim=hidden_dim, + instruction_dim=instruction_dim, + num_memory_slots=num_memory_slots, + num_heads=num_heads, + dropout=dropout, + ) + + +__all__ = ["FireMMIR", "firemm_ir_builder"] diff --git a/pyhazards/models/firepred.py b/pyhazards/models/firepred.py new file mode 100644 index 00000000..0cf67033 --- /dev/null +++ b/pyhazards/models/firepred.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import torch +import torch.nn as nn + + +class FirePred(nn.Module): + """Hybrid multi-temporal CNN wildfire spread predictor inspired by FirePred.""" + + def __init__( + self, + 
history: int = 5, + in_channels: int = 8, + hidden_dim: int = 32, + out_channels: int = 1, + dropout: float = 0.1, + ): + super().__init__() + if history <= 0: + raise ValueError(f"history must be positive, got {history}") + if in_channels <= 0: + raise ValueError(f"in_channels must be positive, got {in_channels}") + if hidden_dim <= 0: + raise ValueError(f"hidden_dim must be positive, got {hidden_dim}") + if out_channels <= 0: + raise ValueError(f"out_channels must be positive, got {out_channels}") + if not 0.0 <= dropout < 1.0: + raise ValueError(f"dropout must be in [0, 1), got {dropout}") + + self.history = int(history) + self.in_channels = int(in_channels) + + self.recent_branch = nn.Sequential( + nn.Conv3d(in_channels, hidden_dim, kernel_size=(3, 3, 3), padding=1), + nn.GELU(), + nn.Conv3d(hidden_dim, hidden_dim, kernel_size=(3, 3, 3), padding=1), + nn.GELU(), + ) + self.daily_branch = nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1), + nn.GELU(), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), + nn.GELU(), + ) + self.snapshot_branch = nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1), + nn.GELU(), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=2, dilation=2), + nn.GELU(), + ) + self.fusion = nn.Sequential( + nn.Conv2d(hidden_dim * 3, hidden_dim * 2, kernel_size=3, padding=1), + nn.GELU(), + nn.Dropout2d(dropout) if dropout > 0 else nn.Identity(), + nn.Conv2d(hidden_dim * 2, out_channels, kernel_size=1), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.ndim != 5: + raise ValueError( + "FirePred expects input shape (batch, history, channels, height, width), " + f"got {tuple(x.shape)}." 
+ ) + if x.size(1) != self.history: + raise ValueError(f"FirePred expected history={self.history}, got {x.size(1)}.") + if x.size(2) != self.in_channels: + raise ValueError(f"FirePred expected in_channels={self.in_channels}, got {x.size(2)}.") + + x_3d = x.permute(0, 2, 1, 3, 4) + recent = torch.mean(self.recent_branch(x_3d), dim=2) + daily = self.daily_branch(torch.mean(x, dim=1)) + snapshot = self.snapshot_branch(x[:, -1]) + fused = torch.cat([recent, daily, snapshot], dim=1) + return self.fusion(fused) + + +def firepred_builder( + task: str, + history: int = 5, + in_channels: int = 8, + hidden_dim: int = 32, + out_channels: int = 1, + dropout: float = 0.1, + **kwargs, +) -> nn.Module: + _ = kwargs + if task.lower() not in {"segmentation", "regression"}: + raise ValueError(f"firepred supports task='segmentation' or 'regression', got {task!r}.") + return FirePred( + history=history, + in_channels=in_channels, + hidden_dim=hidden_dim, + out_channels=out_channels, + dropout=dropout, + ) + + +__all__ = ["FirePred", "firepred_builder"] diff --git a/pyhazards/models/forefire.py b/pyhazards/models/forefire.py index b9b6f565..d117bc77 100644 --- a/pyhazards/models/forefire.py +++ b/pyhazards/models/forefire.py @@ -21,6 +21,7 @@ def __init__( raise ValueError(f"ForeFireAdapter only supports out_channels=1, got {out_channels}") if diffusion_steps <= 0: raise ValueError(f"diffusion_steps must be positive, got {diffusion_steps}") + self.in_channels = int(in_channels) self.diffusion_steps = int(diffusion_steps) kernel = torch.tensor( @@ -37,12 +38,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) if x.size(1) != self.in_channels: raise ValueError(f"ForeFireAdapter expected in_channels={self.in_channels}, got {x.size(1)}.") + state = torch.sigmoid(x[:, :1]) fuel = torch.sigmoid(x[:, 1:2]) wind = torch.tanh(x[:, 2:3]).abs() for _ in range(self.diffusion_steps): neighborhood = F.conv2d(state, self.spread_kernel, padding=1) - state = torch.clamp(0.45 * state + 0.4 * 
neighborhood + 0.1 * fuel + 0.05 * wind, 0.0, 1.0) + state = torch.clamp(0.45 * state + 0.40 * neighborhood + 0.10 * fuel + 0.05 * wind, 0.0, 1.0) return state diff --git a/pyhazards/models/gemini_25_pro_wildfire_prompted.py b/pyhazards/models/gemini_25_pro_wildfire_prompted.py new file mode 100644 index 00000000..6fcac692 --- /dev/null +++ b/pyhazards/models/gemini_25_pro_wildfire_prompted.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import torch.nn as nn + +from .qwen25_vl_wildfire_prompted import Qwen25VLWildfirePrompted + + +class Gemini25ProWildfirePrompted(Qwen25VLWildfirePrompted): + """Benchmark-facing wildfire VLM baseline inspired by Gemini 2.5 Pro.""" + + +def gemini_25_pro_wildfire_prompted_builder( + task: str, + in_channels: int = 6, + out_dim: int = 1, + hidden_dim: int = 96, + prompt_dim: int = 32, + num_prompt_tokens: int = 6, + num_heads: int = 8, + dropout: float = 0.1, + **kwargs, +) -> nn.Module: + _ = kwargs + if task.lower() != "segmentation": + raise ValueError( + f"gemini_25_pro_wildfire_prompted is segmentation-only in PyHazards, got task={task!r}." 
+ ) + return Gemini25ProWildfirePrompted( + in_channels=in_channels, + out_dim=out_dim, + hidden_dim=hidden_dim, + prompt_dim=prompt_dim, + num_prompt_tokens=num_prompt_tokens, + num_heads=num_heads, + dropout=dropout, + ) + + +__all__ = ["Gemini25ProWildfirePrompted", "gemini_25_pro_wildfire_prompted_builder"] diff --git a/pyhazards/models/internvl3_wildfire_prompted.py b/pyhazards/models/internvl3_wildfire_prompted.py new file mode 100644 index 00000000..cc038a9c --- /dev/null +++ b/pyhazards/models/internvl3_wildfire_prompted.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import torch.nn as nn + +from .qwen25_vl_wildfire_prompted import Qwen25VLWildfirePrompted + + +class InternVL3WildfirePrompted(Qwen25VLWildfirePrompted): + """Benchmark-facing wildfire VLM baseline inspired by InternVL3.""" + + +def internvl3_wildfire_prompted_builder( + task: str, + in_channels: int = 6, + out_dim: int = 1, + hidden_dim: int = 96, + prompt_dim: int = 32, + num_prompt_tokens: int = 5, + num_heads: int = 6, + dropout: float = 0.1, + **kwargs, +) -> nn.Module: + _ = kwargs + if task.lower() != "segmentation": + raise ValueError( + f"internvl3_wildfire_prompted is segmentation-only in PyHazards, got task={task!r}." 
+ ) + return InternVL3WildfirePrompted( + in_channels=in_channels, + out_dim=out_dim, + hidden_dim=hidden_dim, + prompt_dim=prompt_dim, + num_prompt_tokens=num_prompt_tokens, + num_heads=num_heads, + dropout=dropout, + ) + + +__all__ = ["InternVL3WildfirePrompted", "internvl3_wildfire_prompted_builder"] diff --git a/pyhazards/models/lightgbm.py b/pyhazards/models/lightgbm.py new file mode 100644 index 00000000..d1ed541a --- /dev/null +++ b/pyhazards/models/lightgbm.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing import Any, Optional + +import numpy as np +import torch.nn as nn + +from ._wildfire_benchmark_utils import EstimatorPort, filter_init_kwargs, require_task + + +class LightGBMModel(EstimatorPort): + """A boosted-tree wildfire occurrence baseline using LightGBM binary classification.""" + + def __init__(self, num_leaves: int = 63, learning_rate: float = 0.05, feature_fraction: float = 0.8, bagging_fraction: float = 0.8, num_boost_round: int = 800): + super().__init__() + self.params = { + "objective": "binary", + "metric": "binary_logloss", + "num_leaves": int(num_leaves), + "learning_rate": float(learning_rate), + "feature_fraction": float(feature_fraction), + "bagging_fraction": float(bagging_fraction), + "verbose": -1, + } + self.num_boost_round = int(num_boost_round) + self.booster = None + + def _fit_numpy( + self, + x_train: np.ndarray, + y_train: np.ndarray, + x_val: Optional[np.ndarray], + y_val: Optional[np.ndarray], + ) -> None: + import lightgbm as lgb + + dtrain = lgb.Dataset(x_train, label=y_train) + valid_sets = [dtrain] + valid_names = ["train"] + if x_val is not None and y_val is not None: + dval = lgb.Dataset(x_val, label=y_val, reference=dtrain) + valid_sets.append(dval) + valid_names.append("val") + self.booster = lgb.train( + params=self.params, + train_set=dtrain, + num_boost_round=self.num_boost_round, + valid_sets=valid_sets, + valid_names=valid_names, + callbacks=[lgb.log_evaluation(period=0)], + ) + + def 
_predict_positive_proba(self, x: np.ndarray) -> np.ndarray: + if self.booster is None: + raise RuntimeError("LightGBM booster is not fitted.") + return np.asarray(self.booster.predict(x), dtype=np.float32) + + +def lightgbm_builder(task: str, **kwargs: Any) -> nn.Module: + require_task(task, {"classification"}, "lightgbm") + build_kwargs = filter_init_kwargs(LightGBMModel, kwargs) + return LightGBMModel(**build_kwargs) + + +__all__ = ["LightGBMModel", "lightgbm_builder"] diff --git a/pyhazards/models/llama4_wildfire_prompted.py b/pyhazards/models/llama4_wildfire_prompted.py new file mode 100644 index 00000000..858d173c --- /dev/null +++ b/pyhazards/models/llama4_wildfire_prompted.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import torch.nn as nn + +from .qwen25_vl_wildfire_prompted import Qwen25VLWildfirePrompted + + +class Llama4WildfirePrompted(Qwen25VLWildfirePrompted): + """Benchmark-facing wildfire multimodal baseline inspired by Meta Llama 4.""" + + +def llama4_wildfire_prompted_builder( + task: str, + in_channels: int = 6, + out_dim: int = 1, + hidden_dim: int = 80, + prompt_dim: int = 32, + num_prompt_tokens: int = 4, + num_heads: int = 8, + dropout: float = 0.1, + **kwargs, +) -> nn.Module: + _ = kwargs + if task.lower() != "segmentation": + raise ValueError( + f"llama4_wildfire_prompted is segmentation-only in PyHazards, got task={task!r}." 
+ ) + return Llama4WildfirePrompted( + in_channels=in_channels, + out_dim=out_dim, + hidden_dim=hidden_dim, + prompt_dim=prompt_dim, + num_prompt_tokens=num_prompt_tokens, + num_heads=num_heads, + dropout=dropout, + ) + + +__all__ = ["Llama4WildfirePrompted", "llama4_wildfire_prompted_builder"] diff --git a/pyhazards/models/logistic_regression.py b/pyhazards/models/logistic_regression.py new file mode 100644 index 00000000..83bba479 --- /dev/null +++ b/pyhazards/models/logistic_regression.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from typing import Any, Optional + +import numpy as np +import torch.nn as nn + +from ._wildfire_benchmark_utils import EstimatorPort, filter_init_kwargs, require_task + + +class LogisticRegressionModel(EstimatorPort): + """A classical tabular binary-classification baseline for wildfire occurrence probability.""" + + def __init__(self, solver: str = "lbfgs", max_iter: int = 500, class_weight: Any = "balanced"): + super().__init__() + from sklearn.linear_model import LogisticRegression + + self.estimator = LogisticRegression( + solver=solver, + max_iter=int(max_iter), + class_weight=class_weight, + ) + + def _fit_numpy( + self, + x_train: np.ndarray, + y_train: np.ndarray, + x_val: Optional[np.ndarray], + y_val: Optional[np.ndarray], + ) -> None: + _ = x_val, y_val + self.estimator.fit(x_train, y_train) + + def _predict_positive_proba(self, x: np.ndarray) -> np.ndarray: + return self.estimator.predict_proba(x)[:, 1] + + +def logistic_regression_builder(task: str, **kwargs: Any) -> nn.Module: + require_task(task, {"classification"}, "logistic_regression") + build_kwargs = filter_init_kwargs(LogisticRegressionModel, kwargs) + return LogisticRegressionModel(**build_kwargs) + + +__all__ = ["LogisticRegressionModel", "logistic_regression_builder"] diff --git a/pyhazards/models/mau.py b/pyhazards/models/mau.py new file mode 100644 index 00000000..fee73423 --- /dev/null +++ b/pyhazards/models/mau.py @@ -0,0 +1,512 @@ +from 
@dataclass
class MAUTrackOConfig:
    """Hyper-parameters for the MAU Track-O training loop."""

    seq_len: int = 6                   # input frames per sample
    in_channels: int = 1
    hidden_channels: int = 12
    lr: float = 1e-3
    weight_decay: float = 1e-4
    batch_size: int = 8
    max_epochs: int = 120
    early_stopping_rounds: int = 16    # patience (epochs) on validation loss
    min_delta: float = 1e-4            # minimum val-loss improvement to reset patience
    seed: int = 42
    pos_weight_clip_max: float = 50.0  # cap on the BCE positive-class weight
    device: str = "cpu"


def binary_ece(y_true: np.ndarray, y_prob: np.ndarray, n_bins: int = 15) -> float:
    """Expected calibration error for binary probabilities over equal-width bins.

    Each non-empty bin contributes its population share times the absolute gap
    between observed frequency and mean predicted probability.
    """
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    total = float(len(y_true))
    ece = 0.0
    for b, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        # The last bin is closed on the right so probability 1.0 is counted.
        if b == n_bins - 1:
            in_bin = (y_prob >= lo) & (y_prob <= hi)
        else:
            in_bin = (y_prob >= lo) & (y_prob < hi)
        if not np.any(in_bin):
            continue
        observed = float(np.mean(y_true[in_bin]))
        predicted = float(np.mean(y_prob[in_bin]))
        ece += (float(np.sum(in_bin)) / total) * abs(observed - predicted)
    return float(ece)


def normalized_consistency_score(mean_day_to_day_change: float) -> float:
    """Map a mean day-to-day change to [0, 1], where 1 means perfectly stable."""
    return float(np.clip(1.0 - float(mean_day_to_day_change), 0.0, 1.0))


class MAUCell(nn.Module):
    """Gated recurrent cell with a slow-moving auxiliary memory state."""

    def __init__(self, hidden_channels: int):
        super().__init__()
        fused_channels = hidden_channels * 3  # input + hidden + memory, concatenated
        self.conv_gate = nn.Conv2d(fused_channels, hidden_channels, kernel_size=3, padding=1)
        self.conv_cand = nn.Conv2d(fused_channels, hidden_channels, kernel_size=3, padding=1)

    def forward(self, x_t: torch.Tensor, h_prev: torch.Tensor, m_prev: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return updated (hidden, memory); all tensors are (B, C, H, W)."""
        stacked = torch.cat((x_t, h_prev, m_prev), dim=1)
        update = torch.sigmoid(self.conv_gate(stacked))
        proposal = torch.tanh(self.conv_cand(stacked))

        # Convex blend of the previous hidden state and the new proposal.
        h_next = update * h_prev + (1.0 - update) * proposal
        # Memory tracks an exponential moving average of hidden states.
        m_next = 0.5 * m_prev + 0.5 * h_next
        return h_next, m_next


class TinyMAU(nn.Module):
    """Minimal MAU-style sequence-to-frame predictor emitting per-pixel logits."""

    def __init__(self, in_channels: int = 1, hidden_channels: int = 12):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.cell = MAUCell(hidden_channels=hidden_channels)
        self.decoder = nn.Sequential(
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, 1, kernel_size=1),
        )

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """Consume x_seq of shape (B, T, C, H, W); return logits of shape (B, 1, H, W)."""
        batch, _, _, height, width = x_seq.shape
        state_shape = (batch, self.hidden_channels, height, width)
        hidden = torch.zeros(state_shape, device=x_seq.device)
        memory = torch.zeros(state_shape, device=x_seq.device)

        # Roll the recurrent cell over the time axis; only the final hidden
        # state is decoded to the output frame.
        for frame in x_seq.unbind(dim=1):
            hidden, memory = self.cell(self.encoder(frame), hidden, memory)

        return self.decoder(hidden)
1.2)), + "sigma": float(rng.uniform(1.8, 4.2)), + "amp": float(rng.uniform(0.8, 2.2)), + } + ) + + terrain = (yy / max(1, image_size - 1)) * rng.uniform(-0.15, 0.15) + wind = (xx / max(1, image_size - 1)) * rng.uniform(-0.25, 0.25) + + last_field = None + for t in range(seq_len + 1): + field = rng.normal(0.0, 0.12, size=(image_size, image_size)) + for s in sources: + cx = float(np.clip(s["cx0"] + s["vx"] * t, 0.0, image_size - 1)) + cy = float(np.clip(s["cy0"] + s["vy"] * t, 0.0, image_size - 1)) + sigma = s["sigma"] + amp = s["amp"] + dist2 = (xx - cx) ** 2 + (yy - cy) ** 2 + field += amp * np.exp(-dist2 / (2.0 * sigma * sigma)) + + signal = field + terrain + wind + rng.normal(0.0, 0.08, size=(image_size, image_size)) + if t < seq_len: + x[i, t, 0] = signal.astype(np.float32) + else: + last_field = field + + if last_field is None: + raise RuntimeError("Synthetic generation failed to produce final frame") + + threshold = float(np.quantile(last_field, 0.90)) + y[i, 0] = (last_field > threshold).astype(np.float32) + + x_mean = float(np.mean(x)) + x_std = float(np.std(x) + 1e-6) + x = (x - x_mean) / x_std + return x, y + + +def split_train_val_test( + x: np.ndarray, + y: np.ndarray, + seed: int, + train_ratio: float = 0.7, + val_ratio: float = 0.15, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + n = x.shape[0] + rng = np.random.default_rng(seed) + idx = rng.permutation(n) + + n_train = max(1, int(n * train_ratio)) + n_val = max(1, int(n * val_ratio)) + n_train = min(n_train, n - 2) + n_val = min(n_val, n - n_train - 1) + + train_idx = idx[:n_train] + val_idx = idx[n_train : n_train + n_val] + test_idx = idx[n_train + n_val :] + + return ( + x[train_idx], + y[train_idx], + x[val_idx], + y[val_idx], + x[test_idx], + y[test_idx], + ) + + +def _choose_device(device_text: str) -> torch.device: + if device_text == "cuda" and torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +def _build_loader(x: 
np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader: + ds = TensorDataset( + torch.from_numpy(x.astype(np.float32)), + torch.from_numpy(y.astype(np.float32)), + ) + return DataLoader(ds, batch_size=batch_size, shuffle=shuffle) + + +def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray: + probs: List[np.ndarray] = [] + model.eval() + with torch.no_grad(): + for xb, _ in loader: + xb = xb.to(device) + logits = model(xb) + p = torch.sigmoid(logits).detach().cpu().numpy() + probs.append(p) + if not probs: + return np.zeros((0,), dtype=np.float32) + return np.concatenate(probs, axis=0).reshape(-1) + + +def train_mau_track_o( + x_train: np.ndarray, + y_train: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + cfg: MAUTrackOConfig, +): + if x_train.ndim != 5 or x_val.ndim != 5: + raise ValueError("x_train and x_val must be 5D arrays [N,T,C,H,W]") + if y_train.ndim != 4 or y_val.ndim != 4: + raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]") + + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + + device = _choose_device(cfg.device) + + model = TinyMAU(in_channels=cfg.in_channels, hidden_channels=cfg.hidden_channels).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) + + total_px = float(y_train.size) + pos_px = float(np.sum(y_train)) + neg_px = max(1.0, total_px - pos_px) + raw_pos_weight = neg_px / max(pos_px, 1.0) + pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max)) + + criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device)) + + train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True) + val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False) + + history: List[Dict[str, float]] = [] + best_epoch = 1 + best_val_loss = float("inf") + best_state: Dict[str, torch.Tensor] | None = None + wait = 0 + + for epoch in 
range(1, cfg.max_epochs + 1): + model.train() + train_losses: List[float] = [] + + for xb, yb in train_loader: + xb = xb.to(device) + yb = yb.to(device) + + optimizer.zero_grad(set_to_none=True) + logits = model(xb) + loss = criterion(logits, yb) + loss.backward() + optimizer.step() + train_losses.append(float(loss.item())) + + model.eval() + val_losses: List[float] = [] + with torch.no_grad(): + for xb, yb in val_loader: + xb = xb.to(device) + yb = yb.to(device) + logits = model(xb) + loss = criterion(logits, yb) + val_losses.append(float(loss.item())) + + tr_loss = float(np.mean(train_losses)) if train_losses else float("nan") + va_loss = float(np.mean(val_losses)) if val_losses else float("nan") + + history.append( + { + "epoch": float(epoch), + "train_loss": tr_loss, + "val_loss": va_loss, + "learning_rate": float(optimizer.param_groups[0]["lr"]), + } + ) + + if va_loss < best_val_loss - cfg.min_delta: + best_val_loss = va_loss + best_epoch = epoch + best_state = deepcopy(model.state_dict()) + wait = 0 + else: + wait += 1 + + if wait >= cfg.early_stopping_rounds: + break + + if best_state is not None: + model.load_state_dict(best_state) + + val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7) + val_true = y_val.reshape(-1).astype(np.float32) + + mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0 + metrics = { + "auprc": float(average_precision_score(val_true, val_prob)), + "auroc": float(roc_auc_score(val_true, val_prob)), + "brier": float(brier_score_loss(val_true, val_prob)), + "nll": float(log_loss(val_true, val_prob)), + "ece": float(binary_ece(val_true, val_prob, n_bins=15)), + "mean_day_to_day_change": mean_change, + "normalized_consistency_score": normalized_consistency_score(mean_change), + } + + return model, history, metrics, best_epoch, pos_weight + + +def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None: + 
def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None:
    """Persist per-epoch history as CSV and render a train/val loss curve PNG."""
    output_dir.mkdir(parents=True, exist_ok=True)

    columns = ["epoch", "train_loss", "val_loss", "learning_rate"]
    with (output_dir / "history.csv").open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(history)

    epochs = [int(row["epoch"]) for row in history]
    train_curve = [float(row["train_loss"]) for row in history]
    val_curve = [float(row["val_loss"]) for row in history]

    plt.figure(figsize=(8, 5))
    plt.plot(epochs, train_curve, label="train_bce", marker="o", linewidth=1.4)
    plt.plot(epochs, val_curve, label="val_bce", marker="s", linewidth=1.2)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("MAU Track-O: train loss vs epoch")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_dir / "loss_curve.png", dpi=150)
    plt.close()


def build_experiment_setting(
    cfg: MAUTrackOConfig,
    best_epoch: int,
    pos_weight: float,
    metrics: Dict[str, float],
) -> Dict[str, Any]:
    """Assemble the experiment-card dictionary persisted alongside each MAU run."""
    training_block = {
        "train_unit": "epoch",
        "max_epochs": cfg.max_epochs,
        "early_stopping_rounds": cfg.early_stopping_rounds,
        "best_epoch": best_epoch,
        "seed": cfg.seed,
        "batch_size": cfg.batch_size,
        "seq_len": cfg.seq_len,
    }
    protocol_block = {
        "discrimination": {"primary": "auprc", "secondary": "auroc"},
        "reliability": ["brier", "nll", "ece"],
        "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"],
    }
    return {
        "benchmark": {
            "task": "Track-O",
            "model_name": "mau",
            "run_time": datetime.now().isoformat(),
        },
        "evaluation_protocol": protocol_block,
        "training": training_block,
        "optimizer": {
            "name": "AdamW",
            "lr": cfg.lr,
            "weight_decay": cfg.weight_decay,
        },
        "learning_weight": {
            "type": "pixel_pos_weight",
            "value": pos_weight,
            "clip_max": cfg.pos_weight_clip_max,
        },
        "params": asdict(cfg),
        "val_metrics": metrics,
        "note": "This module supports both real data and synthetic smoke demonstration.",
    }
seq_len: int = 6, + image_size: int = 24, + max_epochs: int = 60, + early_stopping_rounds: int = 12, +) -> None: + x, y = make_synthetic_fire_sequences( + n_samples=n_samples, + seq_len=seq_len, + image_size=image_size, + seed=seed, + ) + x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed) + + cfg = MAUTrackOConfig( + seed=seed, + seq_len=seq_len, + max_epochs=max_epochs, + early_stopping_rounds=early_stopping_rounds, + device="cpu", + ) + + model, history, val_metrics, best_epoch, pos_weight = train_mau_track_o(x_train, y_train, x_val, y_val, cfg) + + test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False) + test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 1e-7) + test_true = y_test.reshape(-1).astype(np.float32) + + test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0 + test_metrics = { + "auprc": float(average_precision_score(test_true, test_prob)), + "auroc": float(roc_auc_score(test_true, test_prob)), + "brier": float(brier_score_loss(test_true, test_prob)), + "nll": float(log_loss(test_true, test_prob)), + "ece": float(binary_ece(test_true, test_prob, n_bins=15)), + "mean_day_to_day_change": test_mean_change, + "normalized_consistency_score": normalized_consistency_score(test_mean_change), + } + + output_dir.mkdir(parents=True, exist_ok=True) + save_history_and_plot(history, output_dir) + + torch.save( + { + "state_dict": model.state_dict(), + "config": asdict(cfg), + "best_epoch": best_epoch, + }, + output_dir / "mau_model.pt", + ) + + setting = build_experiment_setting(cfg, best_epoch=best_epoch, pos_weight=pos_weight, metrics=val_metrics) + setting["data"] = { + "n_samples": n_samples, + "image_size": image_size, + "seq_len": seq_len, + "split": {"train": int(x_train.shape[0]), "val": int(x_val.shape[0]), "test": int(x_test.shape[0])}, + } + setting["test_metrics"] = test_metrics + + 
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the MAU Track-O synthetic smoke demo.

    Options mirror run_synthetic_demo's keyword arguments; --output_dir
    defaults to None, in which case main() picks a timestamped directory.
    """
    parser = argparse.ArgumentParser(description="Run MAU Track-O synthetic smoke demo")
    parser.add_argument("--output_dir", default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--n_samples", type=int, default=160)
    parser.add_argument("--seq_len", type=int, default=6)
    parser.add_argument("--image_size", type=int, default=24)
    parser.add_argument("--max_epochs", type=int, default=60)
    parser.add_argument("--early_stopping_rounds", type=int, default=12)
    return parser.parse_args()


def main() -> None:
    """CLI entry point: run the synthetic demo into a (possibly timestamped) directory."""
    args = parse_args()
    # Default output lives one directory above this module, under runs_scaffold/.
    base = Path(__file__).resolve().parents[1] / "runs_scaffold"
    out = Path(args.output_dir) if args.output_dir else base / f"mau_synthetic_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    run_synthetic_demo(
        output_dir=out,
        seed=args.seed,
        n_samples=args.n_samples,
        seq_len=args.seq_len,
        image_size=args.image_size,
        max_epochs=args.max_epochs,
        early_stopping_rounds=args.early_stopping_rounds,
    )
    print(f"[done] mau synthetic demo saved to: {out}")


if __name__ == "__main__":
    main()

# NOTE(review): this import sits below the __main__ guard; it executes on every
# module import all the same, but consider moving it to the top of the file.
from ._wildfire_benchmark_utils import SegmentationPort, filter_init_kwargs, require_task


def mau_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module:
    """Registry builder: wrap TinyMAU in a SegmentationPort for the wildfire benchmark.

    Only the "segmentation" task is accepted; unknown kwargs are filtered
    against TinyMAU's __init__ signature before construction.
    """
    require_task(task, {"segmentation"}, "mau")
    init_kwargs = filter_init_kwargs(TinyMAU, {"in_channels": int(in_channels), **kwargs})
    model = TinyMAU(**init_kwargs)
    return SegmentationPort(model=model, out_channels=int(out_dim))


__all__ = ["TinyMAU", "mau_builder"]
class MODISActiveFireC61(nn.Module):
    """Algorithm-inspired MODIS Collection 6.1 active-fire detector with learnable calibration.

    The forward pass derives hand-crafted contextual fire evidence (thermal
    excess over a local background, a split-window term, a cloud/dryness gate)
    and feeds it, together with the raw bands, through a small learned head.
    """

    def __init__(
        self,
        in_channels: int = 5,
        hidden_dim: int = 24,
        out_dim: int = 1,
        context_kernel: int = 9,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Validate the configuration up front with explicit error messages.
        if in_channels < 5:
            raise ValueError(
                "MODISActiveFireC61 expects at least 5 channels: "
                "mid_ir, long_ir, frp_proxy, cloud_free, dryness."
            )
        if hidden_dim <= 0:
            raise ValueError(f"hidden_dim must be positive, got {hidden_dim}")
        if out_dim <= 0:
            raise ValueError(f"out_dim must be positive, got {out_dim}")
        if context_kernel <= 1 or context_kernel % 2 == 0:
            raise ValueError(f"context_kernel must be an odd integer > 1, got {context_kernel}")
        if not 0.0 <= dropout < 1.0:
            raise ValueError(f"dropout must be in [0,1), got {dropout}")

        self.in_channels = int(in_channels)
        # Odd kernel + half padding keeps the spatial size unchanged.
        self.context_pool = nn.AvgPool2d(kernel_size=context_kernel, stride=1, padding=context_kernel // 2)

        n_evidence = self.in_channels + 5  # raw bands + five derived maps
        self.evidence_encoder = nn.Sequential(
            nn.Conv2d(n_evidence, hidden_dim, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.GELU(),
        )
        self.calibration_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Dropout2d(dropout) if dropout > 0 else nn.Identity(),
            nn.Conv2d(hidden_dim, out_dim, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-pixel logits of shape (B, out_dim, H, W).

        Raises ValueError for non-4D input or fewer than 5 channels; extra
        channels beyond in_channels are silently dropped.
        """
        if x.ndim != 4:
            raise ValueError(
                "MODISActiveFireC61 expects input shape (batch, channels, height, width), "
                f"got {tuple(x.shape)}."
            )
        if x.size(1) < 5:
            raise ValueError(f"MODISActiveFireC61 expected at least 5 channels, got {x.size(1)}.")

        x = x[:, : self.in_channels]
        mid_ir, long_ir, frp_proxy, cloud_free, dryness = (x[:, i : i + 1] for i in range(5))

        background = self.context_pool(mid_ir)
        thermal_excess = mid_ir - background
        split_window = mid_ir - long_ir
        fire_signal = F.relu(thermal_excess) + 0.4 * F.relu(split_window)
        contextual_ratio = fire_signal / (torch.abs(background) + 1.0)
        confidence_gate = torch.sigmoid(cloud_free) * torch.sigmoid(dryness)

        # NOTE(review): the gate is ADDED to frp_proxy rather than multiplied;
        # confirm that an additive combination is the intended behavior.
        derived = [
            thermal_excess,
            split_window,
            fire_signal,
            contextual_ratio,
            confidence_gate + frp_proxy,
        ]
        evidence = torch.cat([x, *derived], dim=1)
        return self.calibration_head(self.evidence_encoder(evidence))
@dataclass
class PredRNNv2TrackOConfig:
    """Hyper-parameters for the PredRNN-v2 Track-O training loop."""

    seq_len: int = 6                   # input frames per sample
    in_channels: int = 1
    hidden_channels: int = 12
    lr: float = 1e-3
    weight_decay: float = 1e-4
    batch_size: int = 8
    max_epochs: int = 120
    early_stopping_rounds: int = 16    # patience (epochs) on validation loss
    min_delta: float = 1e-4            # minimum val-loss improvement to reset patience
    seed: int = 42
    pos_weight_clip_max: float = 50.0  # cap on the BCE positive-class weight
    device: str = "cpu"


class SpatioTemporalLSTMCell(nn.Module):
    """ST-LSTM cell with dual memories: temporal (c) and spatiotemporal (m)."""

    def __init__(self, in_channels: int, hidden_channels: int):
        super().__init__()
        self.hidden_channels = hidden_channels

        # Input drives 7 gate maps; hidden 4; spatiotemporal memory 3.
        self.conv_x = nn.Conv2d(in_channels, hidden_channels * 7, kernel_size=3, padding=1)
        self.conv_h = nn.Conv2d(hidden_channels, hidden_channels * 4, kernel_size=3, padding=1)
        self.conv_m = nn.Conv2d(hidden_channels, hidden_channels * 3, kernel_size=3, padding=1)

        self.conv_o = nn.Conv2d(hidden_channels * 2, hidden_channels, kernel_size=1)
        self.conv_last = nn.Conv2d(hidden_channels * 2, hidden_channels, kernel_size=1)

    def forward(
        self,
        x_t: torch.Tensor,
        h_prev: torch.Tensor,
        c_prev: torch.Tensor,
        m_prev: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return updated (h, c, m); all tensors are (B, hidden_channels, H, W)."""
        x_i, x_f, x_g, x_i_m, x_f_m, x_g_m, x_o = self.conv_x(x_t).chunk(7, dim=1)
        h_i, h_f, h_g, h_o = self.conv_h(h_prev).chunk(4, dim=1)
        m_i, m_f, m_g = self.conv_m(m_prev).chunk(3, dim=1)

        # Temporal memory; the +1.0 acts as a forget-gate bias toward remembering.
        i_t = torch.sigmoid(x_i + h_i)
        f_t = torch.sigmoid(x_f + h_f + 1.0)
        c_next = f_t * c_prev + i_t * torch.tanh(x_g + h_g)

        # Spatiotemporal memory, driven only by the input and previous m.
        i_m = torch.sigmoid(x_i_m + m_i)
        f_m = torch.sigmoid(x_f_m + m_f + 1.0)
        m_next = f_m * m_prev + i_m * torch.tanh(x_g_m + m_g)

        # Output gate mixes both memories into the next hidden state.
        joint = torch.cat((c_next, m_next), dim=1)
        o_t = torch.sigmoid(x_o + h_o + self.conv_o(joint))
        h_next = o_t * torch.tanh(self.conv_last(joint))

        return h_next, c_next, m_next


class TinyPredRNNv2(nn.Module):
    """Minimal PredRNN-v2-style sequence-to-frame predictor emitting per-pixel logits."""

    def __init__(self, in_channels: int = 1, hidden_channels: int = 12):
        super().__init__()
        self.hidden_channels = hidden_channels

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.cell = SpatioTemporalLSTMCell(in_channels=hidden_channels, hidden_channels=hidden_channels)
        self.decoder = nn.Sequential(
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, 1, kernel_size=1),
        )

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """Consume x_seq of shape (B, T, C, H, W); return logits of shape (B, 1, H, W)."""
        batch, _, _, height, width = x_seq.shape
        state_shape = (batch, self.hidden_channels, height, width)
        hidden = torch.zeros(state_shape, device=x_seq.device)
        cell_mem = torch.zeros(state_shape, device=x_seq.device)
        spatio_mem = torch.zeros(state_shape, device=x_seq.device)

        # Roll the ST-LSTM cell over time; only the final hidden state is decoded.
        for frame in x_seq.unbind(dim=1):
            hidden, cell_mem, spatio_mem = self.cell(self.encoder(frame), hidden, cell_mem, spatio_mem)

        return self.decoder(hidden)
torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) + + total_px = float(y_train.size) + pos_px = float(np.sum(y_train)) + neg_px = max(1.0, total_px - pos_px) + raw_pos_weight = neg_px / max(pos_px, 1.0) + pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max)) + + criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device)) + + train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True) + val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False) + + history: List[Dict[str, float]] = [] + best_epoch = 1 + best_val_loss = float("inf") + best_state: Dict[str, torch.Tensor] | None = None + wait = 0 + + for epoch in range(1, cfg.max_epochs + 1): + model.train() + train_losses: List[float] = [] + + for xb, yb in train_loader: + xb = xb.to(device) + yb = yb.to(device) + + optimizer.zero_grad(set_to_none=True) + logits = model(xb) + loss = criterion(logits, yb) + loss.backward() + optimizer.step() + train_losses.append(float(loss.item())) + + model.eval() + val_losses: List[float] = [] + with torch.no_grad(): + for xb, yb in val_loader: + xb = xb.to(device) + yb = yb.to(device) + logits = model(xb) + loss = criterion(logits, yb) + val_losses.append(float(loss.item())) + + tr_loss = float(np.mean(train_losses)) if train_losses else float("nan") + va_loss = float(np.mean(val_losses)) if val_losses else float("nan") + + history.append( + { + "epoch": float(epoch), + "train_loss": tr_loss, + "val_loss": va_loss, + "learning_rate": float(optimizer.param_groups[0]["lr"]), + } + ) + + if va_loss < best_val_loss - cfg.min_delta: + best_val_loss = va_loss + best_epoch = epoch + best_state = deepcopy(model.state_dict()) + wait = 0 + else: + wait += 1 + + if wait >= cfg.early_stopping_rounds: + break + + if best_state is not None: + model.load_state_dict(best_state) + + val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 
1e-7) + val_true = y_val.reshape(-1).astype(np.float32) + + mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0 + metrics = { + "auprc": float(average_precision_score(val_true, val_prob)), + "auroc": float(roc_auc_score(val_true, val_prob)), + "brier": float(brier_score_loss(val_true, val_prob)), + "nll": float(log_loss(val_true, val_prob)), + "ece": float(binary_ece(val_true, val_prob, n_bins=15)), + "mean_day_to_day_change": mean_change, + "normalized_consistency_score": normalized_consistency_score(mean_change), + } + + return model, history, metrics, best_epoch, pos_weight + + +def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + history_csv = output_dir / "history.csv" + with history_csv.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"]) + writer.writeheader() + writer.writerows(history) + + x = [int(r["epoch"]) for r in history] + y_tr = [float(r["train_loss"]) for r in history] + y_va = [float(r["val_loss"]) for r in history] + + plt.figure(figsize=(8, 5)) + plt.plot(x, y_tr, label="train_bce", marker="o", linewidth=1.4) + plt.plot(x, y_va, label="val_bce", marker="s", linewidth=1.2) + plt.xlabel("epoch") + plt.ylabel("loss") + plt.title("PredRNN-v2 Track-O: train loss vs epoch") + plt.grid(alpha=0.3) + plt.legend() + plt.tight_layout() + plt.savefig(output_dir / "loss_curve.png", dpi=150) + plt.close() + + +def build_experiment_setting( + cfg: PredRNNv2TrackOConfig, + best_epoch: int, + pos_weight: float, + metrics: Dict[str, float], +) -> Dict[str, Any]: + return { + "benchmark": { + "task": "Track-O", + "model_name": "predrnn_v2", + "run_time": datetime.now().isoformat(), + }, + "evaluation_protocol": { + "discrimination": {"primary": "auprc", "secondary": "auroc"}, + "reliability": ["brier", "nll", "ece"], + "temporal_consistency": 
["mean_day_to_day_change", "normalized_consistency_score"], + }, + "training": { + "train_unit": "epoch", + "max_epochs": cfg.max_epochs, + "early_stopping_rounds": cfg.early_stopping_rounds, + "best_epoch": best_epoch, + "seed": cfg.seed, + "batch_size": cfg.batch_size, + "seq_len": cfg.seq_len, + }, + "optimizer": { + "name": "AdamW", + "lr": cfg.lr, + "weight_decay": cfg.weight_decay, + }, + "learning_weight": { + "type": "pixel_pos_weight", + "value": pos_weight, + "clip_max": cfg.pos_weight_clip_max, + }, + "params": asdict(cfg), + "val_metrics": metrics, + "note": "This module supports both real data and synthetic smoke demonstration.", + } + + +def run_synthetic_demo( + output_dir: Path, + seed: int = 42, + n_samples: int = 160, + seq_len: int = 6, + image_size: int = 24, + max_epochs: int = 60, + early_stopping_rounds: int = 12, +) -> None: + x, y = make_synthetic_fire_sequences( + n_samples=n_samples, + seq_len=seq_len, + image_size=image_size, + seed=seed, + ) + x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed) + + cfg = PredRNNv2TrackOConfig( + seed=seed, + seq_len=seq_len, + max_epochs=max_epochs, + early_stopping_rounds=early_stopping_rounds, + device="cpu", + ) + + model, history, val_metrics, best_epoch, pos_weight = train_predrnn_v2_track_o( + x_train, + y_train, + x_val, + y_val, + cfg, + ) + + test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False) + test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 1e-7) + test_true = y_test.reshape(-1).astype(np.float32) + + test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0 + test_metrics = { + "auprc": float(average_precision_score(test_true, test_prob)), + "auroc": float(roc_auc_score(test_true, test_prob)), + "brier": float(brier_score_loss(test_true, test_prob)), + "nll": float(log_loss(test_true, test_prob)), + "ece": 
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the PredRNN-v2 Track-O synthetic smoke demo.

    Options mirror run_synthetic_demo's keyword arguments; --output_dir
    defaults to None, in which case main() picks a timestamped directory.
    """
    parser = argparse.ArgumentParser(description="Run PredRNN-v2 Track-O synthetic smoke demo")
    parser.add_argument("--output_dir", default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--n_samples", type=int, default=160)
    parser.add_argument("--seq_len", type=int, default=6)
    parser.add_argument("--image_size", type=int, default=24)
    parser.add_argument("--max_epochs", type=int, default=60)
    parser.add_argument("--early_stopping_rounds", type=int, default=12)
    return parser.parse_args()
def predrnn_v2_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module:
    """Registry builder: wrap TinyPredRNNv2 in a SegmentationPort for the wildfire benchmark.

    Only the "segmentation" task is accepted; unknown kwargs are filtered
    against TinyPredRNNv2's __init__ signature before construction.
    """
    require_task(task, {"segmentation"}, "predrnn_v2")
    accepted = filter_init_kwargs(TinyPredRNNv2, {"in_channels": int(in_channels), **kwargs})
    return SegmentationPort(model=TinyPredRNNv2(**accepted), out_channels=int(out_dim))


__all__ = ["TinyPredRNNv2", "predrnn_v2_builder"]
class PrithviBurnScars(nn.Module):
    """Burn-scar segmentation model built on a Prithvi-EO-style temporal backbone.

    Combines coarse backbone features with a full-resolution skip path over the
    temporal mean of the raw frames, then decodes to ``out_dim`` logit channels.
    """

    def __init__(
        self,
        image_size: int = 32,
        in_channels: int = 6,
        out_dim: int = 1,
        patch_size: int = 4,
        embed_dim: int = 128,
        depth: int = 4,
        num_heads: int = 4,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        time_dim: int = 1,
        location_dim: int = 2,
        decoder_channels: int = 64,
    ):
        super().__init__()
        if out_dim <= 0:
            raise ValueError(f"out_dim must be positive, got {out_dim}")
        self.in_channels = int(in_channels)
        # Shared temporal backbone; emits a coarse (B, embed_dim, H', W') map.
        self.backbone = PrithviEOBackbone(
            image_size=image_size,
            in_channels=in_channels,
            patch_size=patch_size,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            time_dim=time_dim,
            location_dim=location_dim,
        )
        skip_layers = [
            nn.Conv2d(in_channels, decoder_channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(decoder_channels, decoder_channels, kernel_size=3, padding=1),
            nn.GELU(),
        ]
        self.skip = nn.Sequential(*skip_layers)
        decode_layers = [
            nn.Conv2d(embed_dim + decoder_channels, decoder_channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(decoder_channels, decoder_channels // 2, kernel_size=3, padding=1),
            nn.GELU(),
        ]
        self.decoder = nn.Sequential(*decode_layers)
        self.head = nn.Conv2d(decoder_channels // 2, out_dim, kernel_size=1)

    def _extract_x(self, inputs: torch.Tensor | Dict[str, Any]) -> torch.Tensor:
        """Pull the (B,T,C,H,W) tensor out of *inputs* and validate it."""
        x = inputs.get("x") if isinstance(inputs, dict) else inputs
        if not isinstance(x, torch.Tensor):
            raise ValueError("PrithviBurnScars expects a tensor input or a dict containing key 'x'.")
        if x.ndim != 5:
            raise ValueError(f"PrithviBurnScars expects input shape (B,T,C,H,W), got {tuple(x.shape)}")
        if x.size(2) != self.in_channels:
            raise ValueError(f"PrithviBurnScars expected in_channels={self.in_channels}, got {x.size(2)}")
        return x

    def forward(self, inputs: torch.Tensor | Dict[str, Any]) -> torch.Tensor:
        frames = self._extract_x(inputs)
        coarse = self.backbone(inputs)
        # Temporal mean of the raw frames feeds a full-resolution skip path.
        high_res = self.skip(frames.mean(dim=1))
        upsampled = F.interpolate(coarse, size=high_res.shape[-2:], mode="bilinear", align_corners=False)
        merged = torch.cat([upsampled, high_res], dim=1)
        return self.head(self.decoder(merged))
class EOSequencePatchEmbed(nn.Module):
    """Per-frame patch embedding for (B, T, C, H, W) EO sequences.

    A 3-D convolution with a singleton temporal kernel patchifies each frame
    independently; output shape is (B, embed_dim, T, H/patch, W/patch).
    """

    def __init__(self, in_channels: int, embed_dim: int, patch_size: int):
        super().__init__()
        window = (1, patch_size, patch_size)
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=window, stride=window)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv3d expects (B, C, T, H, W); the caller supplies (B, T, C, H, W).
        return self.proj(x.transpose(1, 2))
class PrithviEOBackbone(nn.Module):
    """Lightweight temporal-location-aware EO backbone inspired by Prithvi-EO-2.0-TL.

    Consumes a (B, T, C, H, W) sequence (optionally with per-timestep time
    metadata and per-sample location metadata), runs a transformer over the
    full set of spatio-temporal patch tokens, and returns a temporally pooled
    (B, embed_dim, H/patch, W/patch) feature map.
    """

    def __init__(
        self,
        image_size: int = 32,
        in_channels: int = 6,
        patch_size: int = 4,
        embed_dim: int = 128,
        depth: int = 4,
        num_heads: int = 4,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        time_dim: int = 1,
        location_dim: int = 2,
    ):
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError(f"image_size={image_size} must be divisible by patch_size={patch_size}")
        self.image_size = int(image_size)
        self.in_channels = int(in_channels)
        self.patch_size = int(patch_size)
        self.embed_dim = int(embed_dim)
        self.time_dim = int(time_dim)
        self.location_dim = int(location_dim)
        # Tokens per spatial axis after patchification.
        self.grid_size = self.image_size // self.patch_size

        self.patch_embed = EOSequencePatchEmbed(
            in_channels=self.in_channels,
            embed_dim=self.embed_dim,
            patch_size=self.patch_size,
        )
        # Learned spatial position embedding, shared across timesteps.
        self.spatial_pos_embed = nn.Parameter(
            torch.zeros(1, self.grid_size * self.grid_size, self.embed_dim)
        )
        nn.init.trunc_normal_(self.spatial_pos_embed, std=0.02)

        # Linear projections lifting time/location metadata into token space.
        self.time_proj = nn.Linear(self.time_dim, self.embed_dim)
        self.location_proj = nn.Linear(self.location_dim, self.embed_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=int(num_heads),
            dim_feedforward=int(self.embed_dim * mlp_ratio),
            dropout=float(dropout),
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=int(depth))
        self.norm = nn.LayerNorm(self.embed_dim)

    def _unpack_inputs(self, inputs: torch.Tensor | Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """Split *inputs* into (x, time_metadata, location_metadata) and validate x.

        A bare tensor is treated as the image sequence with no metadata.
        Raises ValueError on wrong rank, channel count, or spatial size.
        """
        if isinstance(inputs, dict):
            x = inputs.get("x")
            time_metadata = inputs.get("time_metadata")
            location_metadata = inputs.get("location_metadata")
        else:
            x = inputs
            time_metadata = None
            location_metadata = None

        if not isinstance(x, torch.Tensor):
            raise ValueError("PrithviEOBackbone expects a tensor input or a dict containing key 'x'.")
        if x.ndim != 5:
            raise ValueError(
                "PrithviEOBackbone expects input shape (B, T, C, H, W), "
                f"got {tuple(x.shape)}."
            )
        if x.size(2) != self.in_channels:
            raise ValueError(
                f"PrithviEOBackbone expected in_channels={self.in_channels}, got {x.size(2)}."
            )
        if x.size(-1) != self.image_size or x.size(-2) != self.image_size:
            raise ValueError(
                f"PrithviEOBackbone expected spatial size {self.image_size}x{self.image_size}, "
                f"got {tuple(x.shape[-2:])}."
            )
        return x, time_metadata, location_metadata

    def _build_time_metadata(self, batch: int, timesteps: int, device: torch.device, meta: torch.Tensor | None) -> torch.Tensor:
        """Coerce time metadata to float32 shape (B, T, time_dim).

        ``None`` falls back to a linear 0..1 ramp over timesteps; (B, T) input
        gains a trailing feature axis; a wider last axis is truncated and a
        narrower one is zero-padded to ``time_dim``.
        """
        if meta is None:
            base = torch.linspace(0.0, 1.0, timesteps, device=device).view(1, timesteps, 1)
            return base.expand(batch, -1, -1)
        if meta.ndim == 2:
            meta = meta.unsqueeze(-1)
        if meta.ndim != 3:
            raise ValueError(f"time_metadata must have shape (B,T) or (B,T,D), got {tuple(meta.shape)}")
        if meta.size(0) != batch or meta.size(1) != timesteps:
            raise ValueError(
                f"time_metadata expected batch/timestep=({batch},{timesteps}), got ({meta.size(0)},{meta.size(1)})"
            )
        if meta.size(-1) == self.time_dim:
            return meta.to(device=device, dtype=torch.float32)
        if meta.size(-1) > self.time_dim:
            return meta[..., : self.time_dim].to(device=device, dtype=torch.float32)
        pad = torch.zeros(batch, timesteps, self.time_dim - meta.size(-1), device=device)
        return torch.cat([meta.to(device=device, dtype=torch.float32), pad], dim=-1)

    def _build_location_metadata(self, batch: int, device: torch.device, meta: torch.Tensor | None) -> torch.Tensor:
        """Coerce location metadata to float32 shape (B, location_dim).

        ``None`` yields zeros; a wider last axis is truncated and a narrower
        one is zero-padded, mirroring ``_build_time_metadata``.
        """
        if meta is None:
            return torch.zeros(batch, self.location_dim, device=device)
        if meta.ndim != 2:
            raise ValueError(f"location_metadata must have shape (B,D), got {tuple(meta.shape)}")
        if meta.size(0) != batch:
            raise ValueError(f"location_metadata expected batch={batch}, got {meta.size(0)}")
        if meta.size(-1) == self.location_dim:
            return meta.to(device=device, dtype=torch.float32)
        if meta.size(-1) > self.location_dim:
            return meta[..., : self.location_dim].to(device=device, dtype=torch.float32)
        pad = torch.zeros(batch, self.location_dim - meta.size(-1), device=device)
        return torch.cat([meta.to(device=device, dtype=torch.float32), pad], dim=-1)

    def forward(self, inputs: torch.Tensor | Dict[str, Any]) -> torch.Tensor:
        """Encode the sequence into a (B, embed_dim, grid, grid) feature map."""
        x, time_metadata, location_metadata = self._unpack_inputs(inputs)
        # height/width are unpacked but unused: spatial size was validated above.
        batch, timesteps, _, height, width = x.shape
        device = x.device

        # (B, D, T, H', W') -> (B, T*H'*W', D) token sequence.
        feat = self.patch_embed(x)
        _, _, _, h_tokens, w_tokens = feat.shape
        tokens = feat.permute(0, 2, 3, 4, 1).reshape(batch, timesteps * h_tokens * w_tokens, self.embed_dim)

        # Spatial positions repeat identically at every timestep.
        spatial_pos = self.spatial_pos_embed.unsqueeze(1).expand(-1, timesteps, -1, -1)
        spatial_pos = spatial_pos.reshape(1, timesteps * h_tokens * w_tokens, self.embed_dim)
        tokens = tokens + spatial_pos

        # One time embedding per timestep, broadcast over its spatial tokens.
        time_meta = self._build_time_metadata(batch, timesteps, device, time_metadata)
        time_tokens = self.time_proj(time_meta).unsqueeze(2).expand(-1, -1, h_tokens * w_tokens, -1)
        time_tokens = time_tokens.reshape(batch, timesteps * h_tokens * w_tokens, self.embed_dim)
        tokens = tokens + time_tokens

        # One location embedding per sample, broadcast over all tokens.
        location_meta = self._build_location_metadata(batch, device, location_metadata)
        tokens = tokens + self.location_proj(location_meta).unsqueeze(1)

        encoded = self.norm(self.encoder(tokens))
        # Mean-pool over time, then return channels-first spatial features.
        encoded = encoded.reshape(batch, timesteps, h_tokens, w_tokens, self.embed_dim).mean(dim=1)
        return encoded.permute(0, 3, 1, 2).contiguous()
class PrithviEO2TL(nn.Module):
    """Temporal-location-aware EO segmentation model inspired by Prithvi-EO-2.0-TL.

    Runs the shared backbone, refines its coarse features with a small conv
    decoder, and bilinearly resamples the logits back to the input grid.
    """

    def __init__(
        self,
        image_size: int = 32,
        in_channels: int = 6,
        out_dim: int = 1,
        patch_size: int = 4,
        embed_dim: int = 128,
        depth: int = 4,
        num_heads: int = 4,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        time_dim: int = 1,
        location_dim: int = 2,
        decoder_channels: int = 64,
    ):
        super().__init__()
        if out_dim <= 0:
            raise ValueError(f"out_dim must be positive, got {out_dim}")
        self.image_size = int(image_size)
        self.backbone = PrithviEOBackbone(
            image_size=image_size,
            in_channels=in_channels,
            patch_size=patch_size,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            time_dim=time_dim,
            location_dim=location_dim,
        )
        refine = [
            nn.Conv2d(embed_dim, decoder_channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(decoder_channels, decoder_channels, kernel_size=3, padding=1),
            nn.GELU(),
        ]
        self.decoder = nn.Sequential(*refine)
        self.head = nn.Conv2d(decoder_channels, out_dim, kernel_size=1)

    def forward(self, inputs: torch.Tensor | Dict[str, Any]) -> torch.Tensor:
        frames = inputs["x"] if isinstance(inputs, dict) else inputs
        coarse = self.backbone(inputs)
        logits = self.head(self.decoder(coarse))
        # Logits come out at token resolution; resample to the input grid.
        return F.interpolate(logits, size=frames.shape[-2:], mode="bilinear", align_corners=False)


def prithvi_eo_2_tl_builder(
    task: str,
    image_size: int = 32,
    in_channels: int = 6,
    out_dim: int = 1,
    patch_size: int = 4,
    embed_dim: int = 128,
    depth: int = 4,
    num_heads: int = 4,
    mlp_ratio: float = 4.0,
    dropout: float = 0.1,
    time_dim: int = 1,
    location_dim: int = 2,
    decoder_channels: int = 64,
    **kwargs,
) -> nn.Module:
    """Registry builder: segmentation-only factory for ``PrithviEO2TL``."""
    _ = kwargs
    if task.lower() != "segmentation":
        raise ValueError(f"prithvi_eo_2_tl is segmentation-only, got task={task!r}.")
    return PrithviEO2TL(
        image_size=image_size,
        in_channels=in_channels,
        out_dim=out_dim,
        patch_size=patch_size,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        dropout=dropout,
        time_dim=time_dim,
        location_dim=location_dim,
        decoder_channels=decoder_channels,
    )


__all__ = ["PrithviEOBackbone", "PrithviEO2TL", "prithvi_eo_2_tl_builder"]
class PrithviWxCBackbone(nn.Module):
    """Weather-climate backbone inspired by Prithvi-WxC.

    Consumes a (B, T, C, H, W) weather-variable sequence plus optional
    lead-time and per-variable summary metadata, encodes all spatio-temporal
    patch tokens with a transformer, and returns a temporally pooled
    (B, embed_dim, H/patch, W/patch) feature map.
    """

    def __init__(
        self,
        image_size: int = 32,
        in_channels: int = 8,
        patch_size: int = 4,
        embed_dim: int = 128,
        depth: int = 4,
        num_heads: int = 4,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        lead_time_dim: int = 1,
        variable_summary_dim: int = 8,
    ):
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError(f"image_size={image_size} must be divisible by patch_size={patch_size}")

        self.image_size = int(image_size)
        self.in_channels = int(in_channels)
        self.patch_size = int(patch_size)
        self.embed_dim = int(embed_dim)
        self.lead_time_dim = int(lead_time_dim)
        self.variable_summary_dim = int(variable_summary_dim)
        # Tokens per spatial axis after patchification.
        self.grid_size = self.image_size // self.patch_size

        self.patch_embed = WeatherSequencePatchEmbed(
            in_channels=self.in_channels,
            embed_dim=self.embed_dim,
            patch_size=self.patch_size,
        )
        # Learned spatial position embedding, shared across timesteps.
        self.spatial_pos_embed = nn.Parameter(
            torch.zeros(1, self.grid_size * self.grid_size, self.embed_dim)
        )
        nn.init.trunc_normal_(self.spatial_pos_embed, std=0.02)

        # Metadata projections into token space.
        self.lead_time_proj = nn.Linear(self.lead_time_dim, self.embed_dim)
        self.variable_proj = nn.Linear(self.variable_summary_dim, self.embed_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=int(num_heads),
            dim_feedforward=int(self.embed_dim * mlp_ratio),
            dropout=float(dropout),
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=int(depth))
        self.norm = nn.LayerNorm(self.embed_dim)

    def _unpack_inputs(
        self,
        inputs: torch.Tensor | Dict[str, Any],
    ) -> Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """Split *inputs* into (x, lead_time_hours, variable_summary) and validate x."""
        if isinstance(inputs, dict):
            x = inputs.get("x")
            lead_time = inputs.get("lead_time_hours")
            variable_summary = inputs.get("variable_summary")
        else:
            x = inputs
            lead_time = None
            variable_summary = None

        if not isinstance(x, torch.Tensor):
            raise ValueError("PrithviWxCBackbone expects a tensor input or a dict containing key 'x'.")
        if x.ndim != 5:
            raise ValueError(
                "PrithviWxCBackbone expects input shape (B, T, C, H, W), "
                f"got {tuple(x.shape)}."
            )
        if x.size(2) != self.in_channels:
            raise ValueError(
                f"PrithviWxCBackbone expected in_channels={self.in_channels}, got {x.size(2)}."
            )
        if x.size(-1) != self.image_size or x.size(-2) != self.image_size:
            raise ValueError(
                f"PrithviWxCBackbone expected spatial size {self.image_size}x{self.image_size}, "
                f"got {tuple(x.shape[-2:])}."
            )
        return x, lead_time, variable_summary

    def _build_lead_time(
        self,
        batch: int,
        timesteps: int,
        device: torch.device,
        lead_time: torch.Tensor | None,
    ) -> torch.Tensor:
        """Coerce lead-time metadata to float32 shape (B, T, lead_time_dim).

        ``None`` falls back to a linear 0..1 ramp; (B,) is broadcast over
        timesteps; (B, T) gains a feature axis; the last axis is truncated or
        zero-padded to ``lead_time_dim``.
        """
        if lead_time is None:
            base = torch.linspace(0.0, 1.0, timesteps, device=device).view(1, timesteps, 1)
            return base.expand(batch, -1, -1)
        if lead_time.ndim == 1:
            lead_time = lead_time.view(batch, 1, 1).expand(-1, timesteps, -1)
        elif lead_time.ndim == 2:
            lead_time = lead_time.unsqueeze(-1)
        if lead_time.ndim != 3:
            raise ValueError(
                f"lead_time_hours must have shape (B,), (B,T), or (B,T,D), got {tuple(lead_time.shape)}"
            )
        if lead_time.size(0) != batch or lead_time.size(1) != timesteps:
            raise ValueError(
                "lead_time_hours batch/timestep mismatch: "
                f"expected ({batch},{timesteps}), got ({lead_time.size(0)},{lead_time.size(1)})"
            )
        if lead_time.size(-1) == self.lead_time_dim:
            return lead_time.to(device=device, dtype=torch.float32)
        if lead_time.size(-1) > self.lead_time_dim:
            return lead_time[..., : self.lead_time_dim].to(device=device, dtype=torch.float32)
        pad = torch.zeros(batch, timesteps, self.lead_time_dim - lead_time.size(-1), device=device)
        return torch.cat([lead_time.to(device=device, dtype=torch.float32), pad], dim=-1)

    def _build_variable_summary(
        self,
        x: torch.Tensor,
        variable_summary: torch.Tensor | None,
    ) -> torch.Tensor:
        """Coerce per-variable summaries to float32 shape (B, T, variable_summary_dim).

        ``None`` defaults to the spatial mean of each channel; (B, D) is
        broadcast over timesteps; the last axis is truncated or zero-padded.
        """
        batch, timesteps, channels, _, _ = x.shape
        if variable_summary is None:
            summary = x.mean(dim=(-1, -2))
        else:
            summary = variable_summary
        if summary.ndim == 2:
            summary = summary.unsqueeze(1).expand(-1, timesteps, -1)
        if summary.ndim != 3:
            raise ValueError(
                "variable_summary must have shape (B,D) or (B,T,D), "
                f"got {tuple(summary.shape)}"
            )
        if summary.size(0) != batch or summary.size(1) != timesteps:
            raise ValueError(
                "variable_summary batch/timestep mismatch: "
                f"expected ({batch},{timesteps}), got ({summary.size(0)},{summary.size(1)})"
            )
        if summary.size(-1) > self.variable_summary_dim:
            # FIX: this branch previously returned the truncated tensor without
            # normalizing device/dtype, unlike the pad and exact-size branches
            # (and unlike _build_lead_time); a float64 or off-device summary
            # would then crash variable_proj in forward().
            summary = summary[..., : self.variable_summary_dim].to(device=x.device, dtype=torch.float32)
        elif summary.size(-1) < self.variable_summary_dim:
            pad = torch.zeros(
                batch,
                timesteps,
                self.variable_summary_dim - summary.size(-1),
                device=x.device,
                dtype=torch.float32,
            )
            summary = torch.cat([summary.to(device=x.device, dtype=torch.float32), pad], dim=-1)
        else:
            summary = summary.to(device=x.device, dtype=torch.float32)
        return summary

    def forward(self, inputs: torch.Tensor | Dict[str, Any]) -> torch.Tensor:
        """Encode the sequence into a (B, embed_dim, grid, grid) feature map."""
        x, lead_time, variable_summary = self._unpack_inputs(inputs)
        batch, timesteps, _, _, _ = x.shape

        # (B, D, T, H', W') -> (B, T*H'*W', D) token sequence.
        feat = self.patch_embed(x)
        _, _, _, h_tokens, w_tokens = feat.shape
        tokens = feat.permute(0, 2, 3, 4, 1).reshape(batch, timesteps * h_tokens * w_tokens, self.embed_dim)

        # Spatial positions repeat identically at every timestep.
        spatial_pos = self.spatial_pos_embed.unsqueeze(1).expand(-1, timesteps, -1, -1)
        spatial_pos = spatial_pos.reshape(1, timesteps * h_tokens * w_tokens, self.embed_dim)
        tokens = tokens + spatial_pos

        # One lead-time embedding per timestep, broadcast over spatial tokens.
        lead = self._build_lead_time(batch, timesteps, x.device, lead_time)
        lead_tokens = self.lead_time_proj(lead).unsqueeze(2).expand(-1, -1, h_tokens * w_tokens, -1)
        tokens = tokens + lead_tokens.reshape(batch, timesteps * h_tokens * w_tokens, self.embed_dim)

        # One variable-summary embedding per timestep, broadcast the same way.
        var_summary = self._build_variable_summary(x, variable_summary)
        variable_tokens = self.variable_proj(var_summary).unsqueeze(2).expand(-1, -1, h_tokens * w_tokens, -1)
        tokens = tokens + variable_tokens.reshape(batch, timesteps * h_tokens * w_tokens, self.embed_dim)

        encoded = self.norm(self.encoder(tokens))
        # Mean-pool over time, then return channels-first spatial features.
        encoded = encoded.reshape(batch, timesteps, h_tokens, w_tokens, self.embed_dim).mean(dim=1)
        return encoded.permute(0, 3, 1, 2).contiguous()
class PrithviWxC(nn.Module):
    """Dense wildfire-risk head on top of a Prithvi-WxC-style weather backbone.

    Decodes the backbone's coarse features with a small conv stack and
    bilinearly resamples the per-pixel logits back to the input grid.
    """

    def __init__(
        self,
        image_size: int = 32,
        in_channels: int = 8,
        out_dim: int = 1,
        patch_size: int = 4,
        embed_dim: int = 128,
        depth: int = 4,
        num_heads: int = 4,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        lead_time_dim: int = 1,
        variable_summary_dim: int = 8,
        decoder_channels: int = 64,
    ):
        super().__init__()
        if out_dim <= 0:
            raise ValueError(f"out_dim must be positive, got {out_dim}")
        self.backbone = PrithviWxCBackbone(
            image_size=image_size,
            in_channels=in_channels,
            patch_size=patch_size,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            lead_time_dim=lead_time_dim,
            variable_summary_dim=variable_summary_dim,
        )
        stages = [
            nn.Conv2d(embed_dim, decoder_channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(decoder_channels, decoder_channels, kernel_size=3, padding=1),
            nn.GELU(),
        ]
        self.decoder = nn.Sequential(*stages)
        self.head = nn.Conv2d(decoder_channels, out_dim, kernel_size=1)

    def forward(self, inputs: torch.Tensor | Dict[str, Any]) -> torch.Tensor:
        frames = inputs["x"] if isinstance(inputs, dict) else inputs
        risk = self.head(self.decoder(self.backbone(inputs)))
        # Logits come out at token resolution; resample to the input grid.
        return F.interpolate(risk, size=frames.shape[-2:], mode="bilinear", align_corners=False)


def prithvi_wxc_builder(
    task: str,
    image_size: int = 32,
    in_channels: int = 8,
    out_dim: int = 1,
    patch_size: int = 4,
    embed_dim: int = 128,
    depth: int = 4,
    num_heads: int = 4,
    mlp_ratio: float = 4.0,
    dropout: float = 0.1,
    lead_time_dim: int = 1,
    variable_summary_dim: int = 8,
    decoder_channels: int = 64,
    **kwargs,
) -> nn.Module:
    """Registry builder: segmentation-only factory for ``PrithviWxC``."""
    _ = kwargs
    if task.lower() != "segmentation":
        raise ValueError(f"prithvi_wxc is segmentation-only, got task={task!r}.")
    return PrithviWxC(
        image_size=image_size,
        in_channels=in_channels,
        out_dim=out_dim,
        patch_size=patch_size,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        dropout=dropout,
        lead_time_dim=lead_time_dim,
        variable_summary_dim=variable_summary_dim,
        decoder_channels=decoder_channels,
    )
int = 4, + embed_dim: int = 128, + depth: int = 4, + num_heads: int = 4, + mlp_ratio: float = 4.0, + dropout: float = 0.1, + lead_time_dim: int = 1, + variable_summary_dim: int = 8, + decoder_channels: int = 64, + **kwargs, +) -> nn.Module: + _ = kwargs + if task.lower() != "segmentation": + raise ValueError(f"prithvi_wxc is segmentation-only, got task={task!r}.") + return PrithviWxC( + image_size=image_size, + in_channels=in_channels, + out_dim=out_dim, + patch_size=patch_size, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + dropout=dropout, + lead_time_dim=lead_time_dim, + variable_summary_dim=variable_summary_dim, + decoder_channels=decoder_channels, + ) + + +__all__ = ["PrithviWxCBackbone", "PrithviWxC", "prithvi_wxc_builder"] diff --git a/pyhazards/models/qwen25_vl_wildfire_prompted.py b/pyhazards/models/qwen25_vl_wildfire_prompted.py new file mode 100644 index 00000000..dbb16810 --- /dev/null +++ b/pyhazards/models/qwen25_vl_wildfire_prompted.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +from typing import Any, Dict, Tuple + +import torch +import torch.nn as nn + + +class Qwen25VLWildfirePrompted(nn.Module): + """Prompt-conditioned wildfire segmentation model inspired by Qwen2.5-VL.""" + + def __init__( + self, + in_channels: int = 6, + out_dim: int = 1, + hidden_dim: int = 64, + prompt_dim: int = 24, + num_prompt_tokens: int = 4, + num_heads: int = 4, + dropout: float = 0.1, + ): + super().__init__() + if in_channels <= 0: + raise ValueError(f"in_channels must be positive, got {in_channels}") + if out_dim <= 0: + raise ValueError(f"out_dim must be positive, got {out_dim}") + if hidden_dim <= 0: + raise ValueError(f"hidden_dim must be positive, got {hidden_dim}") + if prompt_dim <= 0: + raise ValueError(f"prompt_dim must be positive, got {prompt_dim}") + if num_prompt_tokens <= 0: + raise ValueError(f"num_prompt_tokens must be positive, got {num_prompt_tokens}") + if hidden_dim % num_heads != 0: + raise 
ValueError(f"hidden_dim={hidden_dim} must be divisible by num_heads={num_heads}") + + self.in_channels = int(in_channels) + self.prompt_dim = int(prompt_dim) + self.num_prompt_tokens = int(num_prompt_tokens) + self.hidden_dim = int(hidden_dim) + + self.visual_encoder = nn.Sequential( + nn.Conv2d(in_channels, hidden_dim // 2, kernel_size=3, padding=1), + nn.GELU(), + nn.Conv2d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1), + nn.GELU(), + ) + self.prompt_proj = nn.Linear(self.prompt_dim, hidden_dim) + self.prompt_bank = nn.Parameter(torch.randn(self.num_prompt_tokens, self.prompt_dim) * 0.02) + self.image_summary = nn.Linear(hidden_dim, hidden_dim) + self.cross_attn = nn.MultiheadAttention( + embed_dim=hidden_dim, + num_heads=num_heads, + dropout=dropout, + batch_first=True, + ) + self.ffn = nn.Sequential( + nn.LayerNorm(hidden_dim), + nn.Linear(hidden_dim, hidden_dim * 2), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim * 2, hidden_dim), + ) + self.decoder = nn.Sequential( + nn.Conv2d(hidden_dim * 2, hidden_dim, kernel_size=3, padding=1), + nn.GELU(), + nn.Dropout2d(dropout) if dropout > 0 else nn.Identity(), + nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=3, padding=1), + nn.GELU(), + ) + self.head = nn.Conv2d(hidden_dim // 2, out_dim, kernel_size=1) + + def _unpack_inputs(self, inputs: torch.Tensor | Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor | None]: + if isinstance(inputs, dict): + x = inputs.get("x") + prompt = inputs.get("prompt_context") + else: + x = inputs + prompt = None + + if not isinstance(x, torch.Tensor): + raise ValueError("Qwen25VLWildfirePrompted expects a tensor input or a dict containing key 'x'.") + if x.ndim != 4: + raise ValueError( + "Qwen25VLWildfirePrompted expects input shape (B, C, H, W), " + f"got {tuple(x.shape)}." + ) + if x.size(1) != self.in_channels: + raise ValueError( + f"Qwen25VLWildfirePrompted expected in_channels={self.in_channels}, got {x.size(1)}." 
+ ) + return x, prompt + + def _coerce_prompt(self, prompt: torch.Tensor | None, batch: int, device: torch.device) -> torch.Tensor: + learned = self.prompt_bank.unsqueeze(0).expand(batch, -1, -1) + if prompt is None: + return self.prompt_proj(learned) + if prompt.ndim == 2: + if prompt.size(0) != batch: + raise ValueError(f"prompt_context must have shape (B,D) or (B,T,D), got {tuple(prompt.shape)}") + prompt = prompt.unsqueeze(1).expand(-1, self.num_prompt_tokens, -1) + elif prompt.ndim == 3: + if prompt.size(0) != batch: + raise ValueError(f"prompt_context must have shape (B,D) or (B,T,D), got {tuple(prompt.shape)}") + if prompt.size(1) != self.num_prompt_tokens: + if prompt.size(1) > self.num_prompt_tokens: + prompt = prompt[:, : self.num_prompt_tokens] + else: + pad = torch.zeros(batch, self.num_prompt_tokens - prompt.size(1), prompt.size(2), device=prompt.device) + prompt = torch.cat([prompt, pad], dim=1) + else: + raise ValueError(f"prompt_context must have rank 2 or 3, got {tuple(prompt.shape)}") + + prompt = prompt.to(device=device, dtype=torch.float32) + if prompt.size(-1) > self.prompt_dim: + prompt = prompt[..., : self.prompt_dim] + elif prompt.size(-1) < self.prompt_dim: + pad = torch.zeros(batch, self.num_prompt_tokens, self.prompt_dim - prompt.size(-1), device=device) + prompt = torch.cat([prompt, pad], dim=-1) + return self.prompt_proj(prompt + learned.to(device=device, dtype=torch.float32)) + + def forward(self, inputs: torch.Tensor | Dict[str, Any]) -> torch.Tensor: + x, prompt = self._unpack_inputs(inputs) + batch = x.size(0) + device = x.device + + feature_map = self.visual_encoder(x) + visual_tokens = feature_map.flatten(2).transpose(1, 2) + pooled = torch.mean(visual_tokens, dim=1, keepdim=True) + pooled = self.image_summary(pooled) + + prompt_tokens = self._coerce_prompt(prompt, batch, device) + query_tokens = torch.cat([prompt_tokens, pooled], dim=1) + attn_out, _ = self.cross_attn(query_tokens, visual_tokens, visual_tokens, 
need_weights=False) + fused_tokens = attn_out + self.ffn(attn_out) + global_token = fused_tokens.mean(dim=1) + + context_map = global_token.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, feature_map.size(-2), feature_map.size(-1)) + decoded = self.decoder(torch.cat([feature_map, context_map], dim=1)) + return self.head(decoded) + + +def qwen25_vl_wildfire_prompted_builder( + task: str, + in_channels: int = 6, + out_dim: int = 1, + hidden_dim: int = 64, + prompt_dim: int = 24, + num_prompt_tokens: int = 4, + num_heads: int = 4, + dropout: float = 0.1, + **kwargs, +) -> nn.Module: + _ = kwargs + if task.lower() != "segmentation": + raise ValueError( + f"qwen25_vl_wildfire_prompted is segmentation-only in PyHazards, got task={task!r}." + ) + return Qwen25VLWildfirePrompted( + in_channels=in_channels, + out_dim=out_dim, + hidden_dim=hidden_dim, + prompt_dim=prompt_dim, + num_prompt_tokens=num_prompt_tokens, + num_heads=num_heads, + dropout=dropout, + ) + + +__all__ = ["Qwen25VLWildfirePrompted", "qwen25_vl_wildfire_prompted_builder"] diff --git a/pyhazards/models/rainformer.py b/pyhazards/models/rainformer.py new file mode 100644 index 00000000..4a3f12c4 --- /dev/null +++ b/pyhazards/models/rainformer.py @@ -0,0 +1,459 @@ +from __future__ import annotations + +import argparse +import csv +import json +import sys +from copy import deepcopy +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import torch +from sklearn.metrics import average_precision_score, brier_score_loss, log_loss, roc_auc_score +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + from pyhazards.models.mau import ( + binary_ece, + make_synthetic_fire_sequences, + 
@dataclass
class RainformerTrackOConfig:
    """Hyper-parameters for the Rainformer Track-O training loop."""

    seq_len: int = 6
    in_channels: int = 1
    hidden_channels: int = 16
    num_heads: int = 4
    num_layers: int = 2
    lr: float = 3e-4
    weight_decay: float = 1e-4
    batch_size: int = 8
    max_epochs: int = 120
    early_stopping_rounds: int = 16
    min_delta: float = 1e-4
    seed: int = 42
    pos_weight_clip_max: float = 50.0
    device: str = "cpu"


class RainMixer(nn.Module):
    """Depthwise-separable spatial mixer: 3x3 depthwise -> 1x1 pointwise -> ReLU."""

    def __init__(self, channels: int):
        super().__init__()
        self.dw = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels)
        self.pw = nn.Conv2d(channels, channels, kernel_size=1)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mixed = self.dw(x)
        mixed = self.pw(mixed)
        return self.act(mixed)


class RainTemporalBlock(nn.Module):
    """Pre-norm transformer block applied along the time axis of per-pixel series."""

    def __init__(self, dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 2, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B*H*W, T, D] — self-attention runs over the time dimension.
        normed = self.norm1(x)
        attended, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attended
        return x + self.ffn(self.norm2(x))


class TinyRainformer(nn.Module):
    """Compact Rainformer-style model: conv encoder + temporal attention + conv decoder."""

    def __init__(self, in_channels: int = 1, hidden_channels: int = 16, num_heads: int = 4, num_layers: int = 2):
        super().__init__()
        self.hidden_channels = hidden_channels

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.rain_mixer = RainMixer(hidden_channels)
        self.temporal_blocks = nn.ModuleList(
            [RainTemporalBlock(hidden_channels, num_heads) for _ in range(num_layers)]
        )

        self.decoder = nn.Sequential(
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, 1, kernel_size=1),
        )

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """Map a [B,T,C,H,W] sequence to [B,1,H,W] logits for the final step."""
        batch, steps, chans, height, width = x_seq.shape

        frames = x_seq.reshape(batch * steps, chans, height, width)
        frames = self.rain_mixer(self.encoder(frames))
        feat_dim = frames.shape[1]

        # Fold space into the batch axis so attention sees a pure time series
        # per pixel: [B*H*W, T, D].
        per_pixel = (
            frames.reshape(batch, steps, feat_dim, height, width)
            .permute(0, 3, 4, 1, 2)
            .contiguous()
            .reshape(batch * height * width, steps, feat_dim)
        )
        for block in self.temporal_blocks:
            per_pixel = block(per_pixel)

        # Keep only the last timestep and unfold back to [B, D, H, W].
        last = per_pixel[:, -1, :].reshape(batch, height, width, feat_dim)
        return self.decoder(last.permute(0, 3, 1, 2).contiguous())
def _choose_device(device_text: str) -> torch.device:
    """Resolve a device string, silently falling back to CPU when CUDA is absent."""
    use_cuda = device_text == "cuda" and torch.cuda.is_available()
    return torch.device("cuda" if use_cuda else "cpu")


def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader:
    """Wrap float32 copies of x/y in a TensorDataset-backed DataLoader."""
    inputs = torch.from_numpy(x.astype(np.float32))
    targets = torch.from_numpy(y.astype(np.float32))
    return DataLoader(TensorDataset(inputs, targets), batch_size=batch_size, shuffle=shuffle)


def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray:
    """Run *model* over *loader* and return flattened sigmoid probabilities.

    Returns an empty float32 array when the loader yields no batches.
    """
    model.eval()
    collected: List[np.ndarray] = []
    with torch.no_grad():
        for batch_x, _ in loader:
            logits = model(batch_x.to(device))
            collected.append(torch.sigmoid(logits).detach().cpu().numpy())
    if not collected:
        return np.zeros((0,), dtype=np.float32)
    return np.concatenate(collected, axis=0).reshape(-1)
def train_rainformer_track_o(
    x_train: np.ndarray,
    y_train: np.ndarray,
    x_val: np.ndarray,
    y_val: np.ndarray,
    cfg: RainformerTrackOConfig,
):
    """Train TinyRainformer with early stopping and compute Track-O val metrics.

    Args:
        x_train / x_val: Input sequences shaped [N, T, C, H, W].
        y_train / y_val: Binary pixel targets shaped [N, 1, H, W].
        cfg: Training hyper-parameters.

    Returns:
        Tuple of (model restored to its best-epoch weights, per-epoch history,
        validation metrics dict, best epoch number, effective BCE pos_weight).

    Raises:
        ValueError: If the input arrays do not have the expected rank.
    """
    if x_train.ndim != 5 or x_val.ndim != 5:
        raise ValueError("x_train and x_val must be 5D arrays [N,T,C,H,W]")
    if y_train.ndim != 4 or y_val.ndim != 4:
        raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]")

    # Seed both torch and numpy so model init and loader shuffling reproduce.
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    device = _choose_device(cfg.device)

    model = TinyRainformer(
        in_channels=cfg.in_channels,
        hidden_channels=cfg.hidden_channels,
        num_heads=cfg.num_heads,
        num_layers=cfg.num_layers,
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    # Class-imbalance correction: weight positives by the neg/pos pixel ratio,
    # clipped to [1, pos_weight_clip_max] to avoid extreme gradients.
    total_px = float(y_train.size)
    pos_px = float(np.sum(y_train))
    neg_px = max(1.0, total_px - pos_px)
    raw_pos_weight = neg_px / max(pos_px, 1.0)
    pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max))

    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

    train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True)
    val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False)

    history: List[Dict[str, float]] = []
    best_epoch = 1
    best_val_loss = float("inf")
    best_state: Dict[str, torch.Tensor] | None = None
    wait = 0  # epochs since the last val-loss improvement

    for epoch in range(1, cfg.max_epochs + 1):
        model.train()
        train_losses: List[float] = []

        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)

            optimizer.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            optimizer.step()
            train_losses.append(float(loss.item()))

        model.eval()
        val_losses: List[float] = []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                logits = model(xb)
                loss = criterion(logits, yb)
                val_losses.append(float(loss.item()))

        tr_loss = float(np.mean(train_losses)) if train_losses else float("nan")
        va_loss = float(np.mean(val_losses)) if val_losses else float("nan")

        history.append(
            {
                "epoch": float(epoch),
                "train_loss": tr_loss,
                "val_loss": va_loss,
                "learning_rate": float(optimizer.param_groups[0]["lr"]),
            }
        )

        # Early stopping: improvement must beat the previous best by min_delta.
        if va_loss < best_val_loss - cfg.min_delta:
            best_val_loss = va_loss
            best_epoch = epoch
            # Snapshot weights so the best epoch can be restored afterwards.
            best_state = deepcopy(model.state_dict())
            wait = 0
        else:
            wait += 1

        if wait >= cfg.early_stopping_rounds:
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    # Clip probabilities away from {0, 1} so log_loss/NLL stays finite.
    val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7)
    val_true = y_val.reshape(-1).astype(np.float32)

    # NOTE(review): this "day-to-day change" proxy differentiates *sorted*
    # probabilities, so it measures distribution smoothness rather than true
    # temporal change — confirm it matches the mau helper's convention.
    mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0
    metrics = {
        "auprc": float(average_precision_score(val_true, val_prob)),
        "auroc": float(roc_auc_score(val_true, val_prob)),
        "brier": float(brier_score_loss(val_true, val_prob)),
        "nll": float(log_loss(val_true, val_prob)),
        "ece": float(binary_ece(val_true, val_prob, n_bins=15)),
        "mean_day_to_day_change": mean_change,
        "normalized_consistency_score": normalized_consistency_score(mean_change),
    }

    return model, history, metrics, best_epoch, pos_weight


def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None:
    """Write the per-epoch history as CSV and a train/val loss curve PNG.

    Creates *output_dir* (and parents) if needed; emits ``history.csv`` and
    ``loss_curve.png`` inside it.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    history_csv = output_dir / "history.csv"
    with history_csv.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"])
        writer.writeheader()
        writer.writerows(history)

    x = [int(r["epoch"]) for r in history]
    y_tr = [float(r["train_loss"]) for r in history]
    y_va = [float(r["val_loss"]) for r in history]

    plt.figure(figsize=(8, 5))
    plt.plot(x, y_tr, label="train_bce", marker="o", linewidth=1.4)
    plt.plot(x, y_va, label="val_bce", marker="s", linewidth=1.2)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("Rainformer Track-O: train loss vs epoch")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_dir / "loss_curve.png", dpi=150)
    plt.close()
build_experiment_setting( + cfg: RainformerTrackOConfig, + best_epoch: int, + pos_weight: float, + metrics: Dict[str, float], +) -> Dict[str, Any]: + return { + "benchmark": { + "task": "Track-O", + "model_name": "rainformer", + "run_time": datetime.now().isoformat(), + }, + "evaluation_protocol": { + "discrimination": {"primary": "auprc", "secondary": "auroc"}, + "reliability": ["brier", "nll", "ece"], + "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"], + }, + "training": { + "train_unit": "epoch", + "max_epochs": cfg.max_epochs, + "early_stopping_rounds": cfg.early_stopping_rounds, + "best_epoch": best_epoch, + "seed": cfg.seed, + "batch_size": cfg.batch_size, + "seq_len": cfg.seq_len, + }, + "optimizer": { + "name": "AdamW", + "lr": cfg.lr, + "weight_decay": cfg.weight_decay, + }, + "learning_weight": { + "type": "pixel_pos_weight", + "value": pos_weight, + "clip_max": cfg.pos_weight_clip_max, + }, + "params": asdict(cfg), + "val_metrics": metrics, + "note": "This module supports both real data and synthetic smoke demonstration.", + } + + +def run_synthetic_demo( + output_dir: Path, + seed: int = 42, + n_samples: int = 160, + seq_len: int = 6, + image_size: int = 24, + max_epochs: int = 60, + early_stopping_rounds: int = 12, +) -> None: + x, y = make_synthetic_fire_sequences( + n_samples=n_samples, + seq_len=seq_len, + image_size=image_size, + seed=seed, + ) + x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed) + + cfg = RainformerTrackOConfig( + seed=seed, + seq_len=seq_len, + max_epochs=max_epochs, + early_stopping_rounds=early_stopping_rounds, + device="cpu", + ) + + model, history, val_metrics, best_epoch, pos_weight = train_rainformer_track_o( + x_train, + y_train, + x_val, + y_val, + cfg, + ) + + test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False) + test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 
1e-7) + test_true = y_test.reshape(-1).astype(np.float32) + + test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0 + test_metrics = { + "auprc": float(average_precision_score(test_true, test_prob)), + "auroc": float(roc_auc_score(test_true, test_prob)), + "brier": float(brier_score_loss(test_true, test_prob)), + "nll": float(log_loss(test_true, test_prob)), + "ece": float(binary_ece(test_true, test_prob, n_bins=15)), + "mean_day_to_day_change": test_mean_change, + "normalized_consistency_score": normalized_consistency_score(test_mean_change), + } + + output_dir.mkdir(parents=True, exist_ok=True) + save_history_and_plot(history, output_dir) + + torch.save( + { + "state_dict": model.state_dict(), + "config": asdict(cfg), + "best_epoch": best_epoch, + }, + output_dir / "rainformer_model.pt", + ) + + setting = build_experiment_setting(cfg, best_epoch=best_epoch, pos_weight=pos_weight, metrics=val_metrics) + setting["data"] = { + "n_samples": n_samples, + "image_size": image_size, + "seq_len": seq_len, + "split": {"train": int(x_train.shape[0]), "val": int(x_val.shape[0]), "test": int(x_test.shape[0])}, + } + setting["test_metrics"] = test_metrics + + (output_dir / "experiment_setting.json").write_text(json.dumps(setting, indent=2), encoding="utf-8") + (output_dir / "metrics.json").write_text( + json.dumps({"val": val_metrics, "test": test_metrics, "best_epoch": best_epoch}, indent=2), + encoding="utf-8", + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run Rainformer Track-O synthetic smoke demo") + parser.add_argument("--output_dir", default=None) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--n_samples", type=int, default=160) + parser.add_argument("--seq_len", type=int, default=6) + parser.add_argument("--image_size", type=int, default=24) + parser.add_argument("--max_epochs", type=int, default=60) + 
parser.add_argument("--early_stopping_rounds", type=int, default=12) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + base = Path(__file__).resolve().parents[1] / "runs_scaffold" + out = ( + Path(args.output_dir) + if args.output_dir + else base / f"rainformer_synthetic_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + ) + + run_synthetic_demo( + output_dir=out, + seed=args.seed, + n_samples=args.n_samples, + seq_len=args.seq_len, + image_size=args.image_size, + max_epochs=args.max_epochs, + early_stopping_rounds=args.early_stopping_rounds, + ) + print(f"[done] rainformer synthetic demo saved to: {out}") + + +if __name__ == "__main__": + main() + +from ._wildfire_benchmark_utils import SegmentationPort, filter_init_kwargs, require_task + + +def rainformer_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module: + require_task(task, {"segmentation"}, "rainformer") + init_kwargs = filter_init_kwargs(TinyRainformer, {"in_channels": int(in_channels), **kwargs}) + model = TinyRainformer(**init_kwargs) + return SegmentationPort(model=model, out_channels=int(out_dim)) + + +__all__ = ["TinyRainformer", "rainformer_builder"] diff --git a/pyhazards/models/random_forest.py b/pyhazards/models/random_forest.py new file mode 100644 index 00000000..3f2e5686 --- /dev/null +++ b/pyhazards/models/random_forest.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Any, Optional + +import numpy as np +import torch.nn as nn + +from ._wildfire_benchmark_utils import EstimatorPort, filter_init_kwargs, require_task + + +class RandomForestModel(EstimatorPort): + """A tree-ensemble baseline for wildfire occurrence probability over tabular features.""" + + def __init__(self, n_estimators: int = 500, max_depth: Optional[int] = None, class_weight: Any = "balanced_subsample"): + super().__init__() + from sklearn.ensemble import RandomForestClassifier + + self.estimator = RandomForestClassifier( + 
n_estimators=int(n_estimators), + max_depth=max_depth, + class_weight=class_weight, + random_state=42, + n_jobs=1, + ) + + def _fit_numpy( + self, + x_train: np.ndarray, + y_train: np.ndarray, + x_val: Optional[np.ndarray], + y_val: Optional[np.ndarray], + ) -> None: + _ = x_val, y_val + self.estimator.fit(x_train, y_train) + + def _predict_positive_proba(self, x: np.ndarray) -> np.ndarray: + return self.estimator.predict_proba(x)[:, 1] + + +def random_forest_builder(task: str, **kwargs: Any) -> nn.Module: + require_task(task, {"classification"}, "random_forest") + build_kwargs = filter_init_kwargs(RandomForestModel, kwargs) + return RandomForestModel(**build_kwargs) + + +__all__ = ["RandomForestModel", "random_forest_builder"] diff --git a/pyhazards/models/resnet18_unet.py b/pyhazards/models/resnet18_unet.py new file mode 100644 index 00000000..a3d42a8a --- /dev/null +++ b/pyhazards/models/resnet18_unet.py @@ -0,0 +1,465 @@ +from __future__ import annotations + +import argparse +import csv +import json +import sys +from copy import deepcopy +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F +from sklearn.metrics import average_precision_score, brier_score_loss, log_loss, roc_auc_score +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + from pyhazards.models.unet import ( + binary_ece, + make_synthetic_fire_maps, + normalized_consistency_score, + split_train_val_test, + ) +else: + from .unet import ( + binary_ece, + make_synthetic_fire_maps, + normalized_consistency_score, + split_train_val_test, + ) + + +@dataclass +class ResNet18UNetTrackOConfig: + in_channels: int = 1 + stem_channels: int = 16 + lr: 
float = 8e-4 + weight_decay: float = 1e-4 + batch_size: int = 8 + max_epochs: int = 120 + early_stopping_rounds: int = 16 + min_delta: float = 1e-4 + seed: int = 42 + pos_weight_clip_max: float = 50.0 + device: str = "cpu" + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_channels: int, out_channels: int, stride: int = 1): + super().__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(out_channels) + + if stride != 1 or in_channels != out_channels: + self.downsample = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_channels), + ) + else: + self.downsample = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(identity) + + out += identity + out = self.relu(out) + return out + + +class DecoderBlock(nn.Module): + def __init__(self, in_channels: int, skip_channels: int, out_channels: int): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels + skip_channels, out_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ) + + def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: + x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False) + x = torch.cat([x, skip], dim=1) + return self.conv(x) + + +class TinyResNet18UNet(nn.Module): + def __init__(self, in_channels: int = 1, stem_channels: int = 16): + super().__init__() + c1, c2, c3, c4 = 
stem_channels, stem_channels * 2, stem_channels * 4, stem_channels * 8 + + self.stem = nn.Sequential( + nn.Conv2d(in_channels, c1, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(c1), + nn.ReLU(inplace=True), + ) + + self.layer1 = self._make_layer(c1, c1, blocks=2, stride=1) + self.layer2 = self._make_layer(c1, c2, blocks=2, stride=2) + self.layer3 = self._make_layer(c2, c3, blocks=2, stride=2) + self.layer4 = self._make_layer(c3, c4, blocks=2, stride=2) + + self.dec3 = DecoderBlock(in_channels=c4, skip_channels=c3, out_channels=c3) + self.dec2 = DecoderBlock(in_channels=c3, skip_channels=c2, out_channels=c2) + self.dec1 = DecoderBlock(in_channels=c2, skip_channels=c1, out_channels=c1) + + self.head = nn.Sequential( + nn.Conv2d(c1, c1, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(c1, 1, kernel_size=1), + ) + + def _make_layer(self, in_channels: int, out_channels: int, blocks: int, stride: int) -> nn.Sequential: + layers: List[nn.Module] = [BasicBlock(in_channels, out_channels, stride=stride)] + for _ in range(1, blocks): + layers.append(BasicBlock(out_channels, out_channels, stride=1)) + return nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x0 = self.stem(x) + x1 = self.layer1(x0) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + x4 = self.layer4(x3) + + y3 = self.dec3(x4, x3) + y2 = self.dec2(y3, x2) + y1 = self.dec1(y2, x1) + return self.head(y1) + + +def _choose_device(device_text: str) -> torch.device: + if device_text == "cuda" and torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader: + ds = TensorDataset( + torch.from_numpy(x.astype(np.float32)), + torch.from_numpy(y.astype(np.float32)), + ) + return DataLoader(ds, batch_size=batch_size, shuffle=shuffle) + + +def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray: + 
probs: List[np.ndarray] = [] + model.eval() + with torch.no_grad(): + for xb, _ in loader: + xb = xb.to(device) + logits = model(xb) + p = torch.sigmoid(logits).detach().cpu().numpy() + probs.append(p) + if not probs: + return np.zeros((0,), dtype=np.float32) + return np.concatenate(probs, axis=0).reshape(-1) + + +def train_resnet18_unet_track_o( + x_train: np.ndarray, + y_train: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + cfg: ResNet18UNetTrackOConfig, +): + if x_train.ndim != 4 or x_val.ndim != 4: + raise ValueError("x_train and x_val must be 4D arrays [N,C,H,W]") + if y_train.ndim != 4 or y_val.ndim != 4: + raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]") + + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + + device = _choose_device(cfg.device) + + model = TinyResNet18UNet(in_channels=cfg.in_channels, stem_channels=cfg.stem_channels).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) + + total_px = float(y_train.size) + pos_px = float(np.sum(y_train)) + neg_px = max(1.0, total_px - pos_px) + raw_pos_weight = neg_px / max(pos_px, 1.0) + pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max)) + + criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device)) + + train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True) + val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False) + + history: List[Dict[str, float]] = [] + best_epoch = 1 + best_val_loss = float("inf") + best_state: Dict[str, torch.Tensor] | None = None + wait = 0 + + for epoch in range(1, cfg.max_epochs + 1): + model.train() + train_losses: List[float] = [] + + for xb, yb in train_loader: + xb = xb.to(device) + yb = yb.to(device) + + optimizer.zero_grad(set_to_none=True) + logits = model(xb) + loss = criterion(logits, yb) + loss.backward() + optimizer.step() + train_losses.append(float(loss.item())) + + model.eval() 
+ val_losses: List[float] = [] + with torch.no_grad(): + for xb, yb in val_loader: + xb = xb.to(device) + yb = yb.to(device) + logits = model(xb) + loss = criterion(logits, yb) + val_losses.append(float(loss.item())) + + tr_loss = float(np.mean(train_losses)) if train_losses else float("nan") + va_loss = float(np.mean(val_losses)) if val_losses else float("nan") + + history.append( + { + "epoch": float(epoch), + "train_loss": tr_loss, + "val_loss": va_loss, + "learning_rate": float(optimizer.param_groups[0]["lr"]), + } + ) + + if va_loss < best_val_loss - cfg.min_delta: + best_val_loss = va_loss + best_epoch = epoch + best_state = deepcopy(model.state_dict()) + wait = 0 + else: + wait += 1 + + if wait >= cfg.early_stopping_rounds: + break + + if best_state is not None: + model.load_state_dict(best_state) + + val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7) + val_true = y_val.reshape(-1).astype(np.float32) + + mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0 + metrics = { + "auprc": float(average_precision_score(val_true, val_prob)), + "auroc": float(roc_auc_score(val_true, val_prob)), + "brier": float(brier_score_loss(val_true, val_prob)), + "nll": float(log_loss(val_true, val_prob)), + "ece": float(binary_ece(val_true, val_prob, n_bins=15)), + "mean_day_to_day_change": mean_change, + "normalized_consistency_score": normalized_consistency_score(mean_change), + } + + return model, history, metrics, best_epoch, pos_weight + + +def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + history_csv = output_dir / "history.csv" + with history_csv.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"]) + writer.writeheader() + writer.writerows(history) + + x = [int(r["epoch"]) for r in history] + y_tr = 
[float(r["train_loss"]) for r in history] + y_va = [float(r["val_loss"]) for r in history] + + plt.figure(figsize=(8, 5)) + plt.plot(x, y_tr, label="train_bce", marker="o", linewidth=1.4) + plt.plot(x, y_va, label="val_bce", marker="s", linewidth=1.2) + plt.xlabel("epoch") + plt.ylabel("loss") + plt.title("ResNet18 U-Net Track-O: train loss vs epoch") + plt.grid(alpha=0.3) + plt.legend() + plt.tight_layout() + plt.savefig(output_dir / "loss_curve.png", dpi=150) + plt.close() + + +def build_experiment_setting( + cfg: ResNet18UNetTrackOConfig, + best_epoch: int, + pos_weight: float, + metrics: Dict[str, float], +) -> Dict[str, Any]: + return { + "benchmark": { + "task": "Track-O", + "model_name": "resnet18_unet", + "run_time": datetime.now().isoformat(), + }, + "evaluation_protocol": { + "discrimination": {"primary": "auprc", "secondary": "auroc"}, + "reliability": ["brier", "nll", "ece"], + "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"], + }, + "training": { + "train_unit": "epoch", + "max_epochs": cfg.max_epochs, + "early_stopping_rounds": cfg.early_stopping_rounds, + "best_epoch": best_epoch, + "seed": cfg.seed, + "batch_size": cfg.batch_size, + }, + "optimizer": { + "name": "AdamW", + "lr": cfg.lr, + "weight_decay": cfg.weight_decay, + }, + "learning_weight": { + "type": "pixel_pos_weight", + "value": pos_weight, + "clip_max": cfg.pos_weight_clip_max, + }, + "params": asdict(cfg), + "val_metrics": metrics, + "note": "This module supports both real data and synthetic smoke demonstration.", + } + + +def run_synthetic_demo( + output_dir: Path, + seed: int = 42, + n_samples: int = 192, + image_size: int = 24, + max_epochs: int = 60, + early_stopping_rounds: int = 12, +) -> None: + x, y = make_synthetic_fire_maps(n_samples=n_samples, image_size=image_size, seed=seed) + x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed) + + cfg = ResNet18UNetTrackOConfig( + seed=seed, + max_epochs=max_epochs, + 
early_stopping_rounds=early_stopping_rounds, + device="cpu", + ) + + model, history, val_metrics, best_epoch, pos_weight = train_resnet18_unet_track_o( + x_train, + y_train, + x_val, + y_val, + cfg, + ) + + test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False) + test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 1e-7) + test_true = y_test.reshape(-1).astype(np.float32) + + test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0 + test_metrics = { + "auprc": float(average_precision_score(test_true, test_prob)), + "auroc": float(roc_auc_score(test_true, test_prob)), + "brier": float(brier_score_loss(test_true, test_prob)), + "nll": float(log_loss(test_true, test_prob)), + "ece": float(binary_ece(test_true, test_prob, n_bins=15)), + "mean_day_to_day_change": test_mean_change, + "normalized_consistency_score": normalized_consistency_score(test_mean_change), + } + + output_dir.mkdir(parents=True, exist_ok=True) + save_history_and_plot(history, output_dir) + + torch.save( + { + "state_dict": model.state_dict(), + "config": asdict(cfg), + "best_epoch": best_epoch, + }, + output_dir / "resnet18_unet_model.pt", + ) + + setting = build_experiment_setting(cfg, best_epoch=best_epoch, pos_weight=pos_weight, metrics=val_metrics) + setting["data"] = { + "n_samples": n_samples, + "image_size": image_size, + "split": {"train": int(x_train.shape[0]), "val": int(x_val.shape[0]), "test": int(x_test.shape[0])}, + } + setting["test_metrics"] = test_metrics + + (output_dir / "experiment_setting.json").write_text(json.dumps(setting, indent=2), encoding="utf-8") + (output_dir / "metrics.json").write_text( + json.dumps({"val": val_metrics, "test": test_metrics, "best_epoch": best_epoch}, indent=2), + encoding="utf-8", + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run ResNet18 U-Net Track-O synthetic smoke demo") 
+ parser.add_argument("--output_dir", default=None) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--n_samples", type=int, default=192) + parser.add_argument("--image_size", type=int, default=24) + parser.add_argument("--max_epochs", type=int, default=60) + parser.add_argument("--early_stopping_rounds", type=int, default=12) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + base = Path(__file__).resolve().parents[1] / "runs_scaffold" + out = ( + Path(args.output_dir) + if args.output_dir + else base / f"resnet18_unet_synthetic_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + ) + + run_synthetic_demo( + output_dir=out, + seed=args.seed, + n_samples=args.n_samples, + image_size=args.image_size, + max_epochs=args.max_epochs, + early_stopping_rounds=args.early_stopping_rounds, + ) + print(f"[done] resnet18_unet synthetic demo saved to: {out}") + + +if __name__ == "__main__": + main() + +from ._wildfire_benchmark_utils import SegmentationPort, filter_init_kwargs, require_task + + +def resnet18_unet_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module: + require_task(task, {"segmentation"}, "resnet18_unet") + init_kwargs = filter_init_kwargs(TinyResNet18UNet, {"in_channels": int(in_channels), **kwargs}) + model = TinyResNet18UNet(**init_kwargs) + return SegmentationPort(model=model, out_channels=int(out_dim)) + + +__all__ = ["TinyResNet18UNet", "resnet18_unet_builder"] diff --git a/pyhazards/models/segformer.py b/pyhazards/models/segformer.py new file mode 100644 index 00000000..440e0a66 --- /dev/null +++ b/pyhazards/models/segformer.py @@ -0,0 +1,568 @@ +from __future__ import annotations + +import argparse +import csv +import json +import sys +from copy import deepcopy +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as 
plt +import numpy as np +import torch +import torch.nn.functional as F +from sklearn.metrics import average_precision_score, brier_score_loss, log_loss, roc_auc_score +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + from pyhazards.models.mau import ( + binary_ece, + make_synthetic_fire_sequences, + normalized_consistency_score, + split_train_val_test, + ) +else: + from .mau import ( + binary_ece, + make_synthetic_fire_sequences, + normalized_consistency_score, + split_train_val_test, + ) + + +@dataclass +class SegFormerTrackOConfig: + seq_len: int = 6 + in_channels: int = 1 + embed_dims: Tuple[int, int] = (16, 32) + num_heads: Tuple[int, int] = (1, 2) + sr_ratios: Tuple[int, int] = (4, 2) + mlp_ratio: float = 2.0 + dropout: float = 0.1 + lr: float = 2e-4 + weight_decay: float = 1e-4 + batch_size: int = 8 + max_epochs: int = 120 + early_stopping_rounds: int = 16 + min_delta: float = 1e-4 + seed: int = 42 + pos_weight_clip_max: float = 50.0 + device: str = "cpu" + + +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_channels: int, embed_dim: int, patch_size: int, stride: int): + super().__init__() + self.proj = nn.Conv2d( + in_channels, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=patch_size // 2, + ) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]: + x = self.proj(x) + b, c, h, w = x.shape + x = x.flatten(2).transpose(1, 2).contiguous() # [B,N,C] + x = self.norm(x) + return x, h, w + + +class MixFFN(nn.Module): + def __init__(self, dim: int, mlp_ratio: float, dropout: float): + super().__init__() + hidden = int(dim * mlp_ratio) + self.fc1 = nn.Linear(dim, hidden) + self.dwconv = nn.Conv2d(hidden, hidden, kernel_size=3, padding=1, groups=hidden) + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.fc2 = nn.Linear(hidden, 
dim) + + def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor: + x = self.fc1(x) + b, n, c = x.shape + x_img = x.transpose(1, 2).reshape(b, c, h, w) + x_img = self.dwconv(x_img) + x = x_img.flatten(2).transpose(1, 2).contiguous() + x = self.act(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class EfficientSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, sr_ratio: int, dropout: float): + super().__init__() + self.sr_ratio = sr_ratio + self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True) + self.norm = nn.LayerNorm(dim) + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + else: + self.sr = None + + def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor: + kv = x + if self.sr is not None: + b, n, c = x.shape + x_img = x.transpose(1, 2).reshape(b, c, h, w) + x_img = self.sr(x_img) + kv = x_img.flatten(2).transpose(1, 2).contiguous() + kv = self.norm(kv) + + out, _ = self.attn(x, kv, kv, need_weights=False) + return out + + +class TransformerEncoderBlock(nn.Module): + def __init__(self, dim: int, num_heads: int, sr_ratio: int, mlp_ratio: float, dropout: float): + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.attn = EfficientSelfAttention(dim=dim, num_heads=num_heads, sr_ratio=sr_ratio, dropout=dropout) + self.norm2 = nn.LayerNorm(dim) + self.ffn = MixFFN(dim=dim, mlp_ratio=mlp_ratio, dropout=dropout) + + def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor: + x = x + self.attn(self.norm1(x), h, w) + x = x + self.ffn(self.norm2(x), h, w) + return x + + +class TinySegFormerEncoder(nn.Module): + def __init__( + self, + in_channels: int = 1, + embed_dims: Tuple[int, int] = (16, 32), + num_heads: Tuple[int, int] = (1, 2), + sr_ratios: Tuple[int, int] = (4, 2), + mlp_ratio: float = 2.0, + dropout: float = 0.1, + ): + super().__init__() + d1, d2 = embed_dims + + self.patch1 = 
OverlapPatchEmbed(in_channels=in_channels, embed_dim=d1, patch_size=7, stride=2) + self.block1 = TransformerEncoderBlock( + dim=d1, + num_heads=num_heads[0], + sr_ratio=sr_ratios[0], + mlp_ratio=mlp_ratio, + dropout=dropout, + ) + + self.patch2 = OverlapPatchEmbed(in_channels=d1, embed_dim=d2, patch_size=3, stride=2) + self.block2 = TransformerEncoderBlock( + dim=d2, + num_heads=num_heads[1], + sr_ratio=sr_ratios[1], + mlp_ratio=mlp_ratio, + dropout=dropout, + ) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x1, h1, w1 = self.patch1(x) + x1 = self.block1(x1, h1, w1) + f1 = x1.transpose(1, 2).reshape(x1.shape[0], -1, h1, w1) + + x2, h2, w2 = self.patch2(f1) + x2 = self.block2(x2, h2, w2) + f2 = x2.transpose(1, 2).reshape(x2.shape[0], -1, h2, w2) + return f1, f2 + + +class SegFormerHead(nn.Module): + def __init__(self, in_dims: Tuple[int, int], decoder_dim: int = 32, dropout: float = 0.1): + super().__init__() + self.proj1 = nn.Conv2d(in_dims[0], decoder_dim, kernel_size=1) + self.proj2 = nn.Conv2d(in_dims[1], decoder_dim, kernel_size=1) + self.fuse = nn.Sequential( + nn.Conv2d(decoder_dim * 2, decoder_dim, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Dropout2d(dropout), + nn.Conv2d(decoder_dim, 1, kernel_size=1), + ) + + def forward(self, f1: torch.Tensor, f2: torch.Tensor, out_hw: Tuple[int, int]) -> torch.Tensor: + p1 = self.proj1(f1) + p2 = self.proj2(f2) + p2 = F.interpolate(p2, size=p1.shape[-2:], mode="bilinear", align_corners=False) + logits_small = self.fuse(torch.cat([p1, p2], dim=1)) + return F.interpolate(logits_small, size=out_hw, mode="bilinear", align_corners=False) + + +class TinySegFormer(nn.Module): + def __init__( + self, + in_channels: int = 1, + embed_dims: Tuple[int, int] = (16, 32), + num_heads: Tuple[int, int] = (1, 2), + sr_ratios: Tuple[int, int] = (4, 2), + mlp_ratio: float = 2.0, + dropout: float = 0.1, + ): + super().__init__() + self.encoder = TinySegFormerEncoder( + in_channels=in_channels, + 
embed_dims=embed_dims, + num_heads=num_heads, + sr_ratios=sr_ratios, + mlp_ratio=mlp_ratio, + dropout=dropout, + ) + self.decode_head = SegFormerHead(in_dims=embed_dims, decoder_dim=embed_dims[1], dropout=dropout) + + def _temporal_fusion(self, x_seq: torch.Tensor) -> torch.Tensor: + # x_seq: [B,T,C,H,W] + frame_scores = x_seq.mean(dim=(2, 3, 4)) + weights = torch.softmax(frame_scores, dim=1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + return torch.sum(x_seq * weights, dim=1) + + def forward(self, x_seq: torch.Tensor) -> torch.Tensor: + # x_seq: [B,T,C,H,W] + _, _, _, h, w = x_seq.shape + x_img = self._temporal_fusion(x_seq) + f1, f2 = self.encoder(x_img) + return self.decode_head(f1, f2, out_hw=(h, w)) + + +def _choose_device(device_text: str) -> torch.device: + if device_text == "cuda" and torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader: + ds = TensorDataset( + torch.from_numpy(x.astype(np.float32)), + torch.from_numpy(y.astype(np.float32)), + ) + return DataLoader(ds, batch_size=batch_size, shuffle=shuffle) + + +def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray: + probs: List[np.ndarray] = [] + model.eval() + with torch.no_grad(): + for xb, _ in loader: + xb = xb.to(device) + logits = model(xb) + p = torch.sigmoid(logits).detach().cpu().numpy() + probs.append(p) + if not probs: + return np.zeros((0,), dtype=np.float32) + return np.concatenate(probs, axis=0).reshape(-1) + + +def train_segformer_track_o( + x_train: np.ndarray, + y_train: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + cfg: SegFormerTrackOConfig, +): + if x_train.ndim != 5 or x_val.ndim != 5: + raise ValueError("x_train and x_val must be 5D arrays [N,T,C,H,W]") + if y_train.ndim != 4 or y_val.ndim != 4: + raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]") + + torch.manual_seed(cfg.seed) + 
np.random.seed(cfg.seed) + + device = _choose_device(cfg.device) + + model = TinySegFormer( + in_channels=cfg.in_channels, + embed_dims=cfg.embed_dims, + num_heads=cfg.num_heads, + sr_ratios=cfg.sr_ratios, + mlp_ratio=cfg.mlp_ratio, + dropout=cfg.dropout, + ).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) + + total_px = float(y_train.size) + pos_px = float(np.sum(y_train)) + neg_px = max(1.0, total_px - pos_px) + raw_pos_weight = neg_px / max(pos_px, 1.0) + pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max)) + + criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device)) + + train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True) + val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False) + + history: List[Dict[str, float]] = [] + best_epoch = 1 + best_val_loss = float("inf") + best_state: Dict[str, torch.Tensor] | None = None + wait = 0 + + for epoch in range(1, cfg.max_epochs + 1): + model.train() + train_losses: List[float] = [] + + for xb, yb in train_loader: + xb = xb.to(device) + yb = yb.to(device) + + optimizer.zero_grad(set_to_none=True) + logits = model(xb) + loss = criterion(logits, yb) + loss.backward() + optimizer.step() + train_losses.append(float(loss.item())) + + model.eval() + val_losses: List[float] = [] + with torch.no_grad(): + for xb, yb in val_loader: + xb = xb.to(device) + yb = yb.to(device) + logits = model(xb) + loss = criterion(logits, yb) + val_losses.append(float(loss.item())) + + tr_loss = float(np.mean(train_losses)) if train_losses else float("nan") + va_loss = float(np.mean(val_losses)) if val_losses else float("nan") + + history.append( + { + "epoch": float(epoch), + "train_loss": tr_loss, + "val_loss": va_loss, + "learning_rate": float(optimizer.param_groups[0]["lr"]), + } + ) + + if va_loss < best_val_loss - cfg.min_delta: + best_val_loss = va_loss + best_epoch = epoch + 
best_state = deepcopy(model.state_dict()) + wait = 0 + else: + wait += 1 + + if wait >= cfg.early_stopping_rounds: + break + + if best_state is not None: + model.load_state_dict(best_state) + + val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7) + val_true = y_val.reshape(-1).astype(np.float32) + + mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0 + metrics = { + "auprc": float(average_precision_score(val_true, val_prob)), + "auroc": float(roc_auc_score(val_true, val_prob)), + "brier": float(brier_score_loss(val_true, val_prob)), + "nll": float(log_loss(val_true, val_prob)), + "ece": float(binary_ece(val_true, val_prob, n_bins=15)), + "mean_day_to_day_change": mean_change, + "normalized_consistency_score": normalized_consistency_score(mean_change), + } + + return model, history, metrics, best_epoch, pos_weight + + +def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + history_csv = output_dir / "history.csv" + with history_csv.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"]) + writer.writeheader() + writer.writerows(history) + + x = [int(r["epoch"]) for r in history] + y_tr = [float(r["train_loss"]) for r in history] + y_va = [float(r["val_loss"]) for r in history] + + plt.figure(figsize=(8, 5)) + plt.plot(x, y_tr, label="train_bce", marker="o", linewidth=1.4) + plt.plot(x, y_va, label="val_bce", marker="s", linewidth=1.2) + plt.xlabel("epoch") + plt.ylabel("loss") + plt.title("SegFormer Track-O: train loss vs epoch") + plt.grid(alpha=0.3) + plt.legend() + plt.tight_layout() + plt.savefig(output_dir / "loss_curve.png", dpi=150) + plt.close() + + +def build_experiment_setting( + cfg: SegFormerTrackOConfig, + best_epoch: int, + pos_weight: float, + metrics: Dict[str, float], +) -> Dict[str, Any]: + return { + 
"benchmark": { + "task": "Track-O", + "model_name": "segformer", + "run_time": datetime.now().isoformat(), + }, + "evaluation_protocol": { + "discrimination": {"primary": "auprc", "secondary": "auroc"}, + "reliability": ["brier", "nll", "ece"], + "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"], + }, + "training": { + "train_unit": "epoch", + "max_epochs": cfg.max_epochs, + "early_stopping_rounds": cfg.early_stopping_rounds, + "best_epoch": best_epoch, + "seed": cfg.seed, + "batch_size": cfg.batch_size, + "seq_len": cfg.seq_len, + }, + "optimizer": { + "name": "AdamW", + "lr": cfg.lr, + "weight_decay": cfg.weight_decay, + }, + "learning_weight": { + "type": "pixel_pos_weight", + "value": pos_weight, + "clip_max": cfg.pos_weight_clip_max, + }, + "params": asdict(cfg), + "val_metrics": metrics, + "note": "This module supports both real data and synthetic smoke demonstration.", + } + + +def run_synthetic_demo( + output_dir: Path, + seed: int = 42, + n_samples: int = 160, + seq_len: int = 6, + image_size: int = 24, + max_epochs: int = 60, + early_stopping_rounds: int = 12, +) -> None: + x, y = make_synthetic_fire_sequences( + n_samples=n_samples, + seq_len=seq_len, + image_size=image_size, + seed=seed, + ) + x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed) + + cfg = SegFormerTrackOConfig( + seed=seed, + seq_len=seq_len, + max_epochs=max_epochs, + early_stopping_rounds=early_stopping_rounds, + device="cpu", + ) + + model, history, val_metrics, best_epoch, pos_weight = train_segformer_track_o( + x_train, + y_train, + x_val, + y_val, + cfg, + ) + + test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False) + test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 1e-7) + test_true = y_test.reshape(-1).astype(np.float32) + + test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0 + 
test_metrics = { + "auprc": float(average_precision_score(test_true, test_prob)), + "auroc": float(roc_auc_score(test_true, test_prob)), + "brier": float(brier_score_loss(test_true, test_prob)), + "nll": float(log_loss(test_true, test_prob)), + "ece": float(binary_ece(test_true, test_prob, n_bins=15)), + "mean_day_to_day_change": test_mean_change, + "normalized_consistency_score": normalized_consistency_score(test_mean_change), + } + + output_dir.mkdir(parents=True, exist_ok=True) + save_history_and_plot(history, output_dir) + + torch.save( + { + "state_dict": model.state_dict(), + "config": asdict(cfg), + "best_epoch": best_epoch, + }, + output_dir / "segformer_model.pt", + ) + + setting = build_experiment_setting(cfg, best_epoch=best_epoch, pos_weight=pos_weight, metrics=val_metrics) + setting["data"] = { + "n_samples": n_samples, + "image_size": image_size, + "seq_len": seq_len, + "split": {"train": int(x_train.shape[0]), "val": int(x_val.shape[0]), "test": int(x_test.shape[0])}, + } + setting["test_metrics"] = test_metrics + + (output_dir / "experiment_setting.json").write_text(json.dumps(setting, indent=2), encoding="utf-8") + (output_dir / "metrics.json").write_text( + json.dumps({"val": val_metrics, "test": test_metrics, "best_epoch": best_epoch}, indent=2), + encoding="utf-8", + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run SegFormer Track-O synthetic smoke demo") + parser.add_argument("--output_dir", default=None) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--n_samples", type=int, default=160) + parser.add_argument("--seq_len", type=int, default=6) + parser.add_argument("--image_size", type=int, default=24) + parser.add_argument("--max_epochs", type=int, default=60) + parser.add_argument("--early_stopping_rounds", type=int, default=12) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + base = Path(__file__).resolve().parents[1] / "runs_scaffold" 
+ out = ( + Path(args.output_dir) + if args.output_dir + else base / f"segformer_synthetic_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + ) + + run_synthetic_demo( + output_dir=out, + seed=args.seed, + n_samples=args.n_samples, + seq_len=args.seq_len, + image_size=args.image_size, + max_epochs=args.max_epochs, + early_stopping_rounds=args.early_stopping_rounds, + ) + print(f"[done] segformer synthetic demo saved to: {out}") + + +if __name__ == "__main__": + main() + +from ._wildfire_benchmark_utils import SegmentationPort, filter_init_kwargs, require_task + + +def segformer_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module: + require_task(task, {"segmentation"}, "segformer") + init_kwargs = filter_init_kwargs(TinySegFormer, {"in_channels": int(in_channels), **kwargs}) + model = TinySegFormer(**init_kwargs) + return SegmentationPort(model=model, out_channels=int(out_dim)) + + +__all__ = ["TinySegFormer", "segformer_builder"] diff --git a/pyhazards/models/swin_unet.py b/pyhazards/models/swin_unet.py new file mode 100644 index 00000000..dd617281 --- /dev/null +++ b/pyhazards/models/swin_unet.py @@ -0,0 +1,527 @@ +from __future__ import annotations + +import argparse +import csv +import json +import sys +from copy import deepcopy +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F +from sklearn.metrics import average_precision_score, brier_score_loss, log_loss, roc_auc_score +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + from pyhazards.models.mau import ( + binary_ece, + make_synthetic_fire_sequences, + normalized_consistency_score, + split_train_val_test, + ) 
+else: + from .mau import ( + binary_ece, + make_synthetic_fire_sequences, + normalized_consistency_score, + split_train_val_test, + ) + + +@dataclass +class SwinUNetTrackOConfig: + seq_len: int = 6 + in_channels: int = 1 + embed_dims: Tuple[int, int] = (16, 32) + num_heads: Tuple[int, int] = (1, 2) + window_size: int = 3 + mlp_ratio: float = 2.0 + dropout: float = 0.1 + lr: float = 2e-4 + weight_decay: float = 1e-4 + batch_size: int = 8 + max_epochs: int = 120 + early_stopping_rounds: int = 16 + min_delta: float = 1e-4 + seed: int = 42 + pos_weight_clip_max: float = 50.0 + device: str = "cpu" + + +def _window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, int, int]: + # x: [B,H,W,C] -> windows: [B*nw, ws*ws, C] + b, h, w, c = x.shape + pad_h = (window_size - h % window_size) % window_size + pad_w = (window_size - w % window_size) % window_size + + if pad_h > 0 or pad_w > 0: + x = x.permute(0, 3, 1, 2).contiguous() + x = F.pad(x, (0, pad_w, 0, pad_h)) + x = x.permute(0, 2, 3, 1).contiguous() + + hp, wp = h + pad_h, w + pad_w + x = x.view(b, hp // window_size, window_size, wp // window_size, window_size, c) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = x.view(-1, window_size * window_size, c) + return windows, hp, wp + + +def _window_reverse(windows: torch.Tensor, window_size: int, hp: int, wp: int, b: int) -> torch.Tensor: + # windows: [B*nw, ws*ws, C] -> [B,Hp,Wp,C] + c = windows.shape[-1] + x = windows.view(b, hp // window_size, wp // window_size, window_size, window_size, c) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous() + return x.view(b, hp, wp, c) + + +class SwinBlock(nn.Module): + def __init__(self, dim: int, num_heads: int, window_size: int, shift_size: int, mlp_ratio: float, dropout: float): + super().__init__() + self.window_size = window_size + self.shift_size = shift_size + + self.norm1 = nn.LayerNorm(dim) + self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True) + 
self.norm2 = nn.LayerNorm(dim) + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(int(dim * mlp_ratio), dim), + nn.Dropout(dropout), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [B,C,H,W] + b, c, h, w = x.shape + residual = x + x = x.permute(0, 2, 3, 1).contiguous() # [B,H,W,C] + + if self.shift_size > 0: + x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + + windows, hp, wp = _window_partition(x, self.window_size) + w_norm = self.norm1(windows) + attn_out, _ = self.attn(w_norm, w_norm, w_norm, need_weights=False) + windows = windows + attn_out + x = _window_reverse(windows, self.window_size, hp, wp, b) + + if self.shift_size > 0: + x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + x = x[:, :h, :w, :].contiguous() + x = x.permute(0, 3, 1, 2).contiguous() + x = residual + x + + tokens = x.flatten(2).transpose(1, 2).contiguous() # [B,HW,C] + tokens = tokens + self.mlp(self.norm2(tokens)) + x = tokens.transpose(1, 2).reshape(b, c, h, w).contiguous() + return x + + +class TinySwinUNet(nn.Module): + def __init__( + self, + in_channels: int = 1, + embed_dims: Tuple[int, int] = (16, 32), + num_heads: Tuple[int, int] = (1, 2), + window_size: int = 3, + mlp_ratio: float = 2.0, + dropout: float = 0.1, + ): + super().__init__() + c1, c2 = embed_dims + shift = max(1, window_size // 2) + + self.patch_embed = nn.Conv2d(in_channels, c1, kernel_size=3, stride=2, padding=1) + self.stage1 = nn.Sequential( + SwinBlock(c1, num_heads[0], window_size, shift_size=0, mlp_ratio=mlp_ratio, dropout=dropout), + SwinBlock(c1, num_heads[0], window_size, shift_size=shift, mlp_ratio=mlp_ratio, dropout=dropout), + ) + + self.downsample = nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1) + self.stage2 = nn.Sequential( + SwinBlock(c2, num_heads[1], window_size, shift_size=0, mlp_ratio=mlp_ratio, dropout=dropout), + SwinBlock(c2, num_heads[1], window_size, 
shift_size=shift, mlp_ratio=mlp_ratio, dropout=dropout), + ) + + self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2) + self.fuse1 = nn.Sequential( + nn.Conv2d(c1 * 2, c1, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ) + self.up2 = nn.ConvTranspose2d(c1, c1, kernel_size=2, stride=2) + self.head = nn.Sequential( + nn.Conv2d(c1, c1, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(c1, 1, kernel_size=1), + ) + + def _temporal_fusion(self, x_seq: torch.Tensor) -> torch.Tensor: + # x_seq: [B,T,C,H,W] + frame_scores = x_seq.mean(dim=(2, 3, 4)) + weights = torch.softmax(frame_scores, dim=1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + return torch.sum(x_seq * weights, dim=1) + + def forward(self, x_seq: torch.Tensor) -> torch.Tensor: + # x_seq: [B,T,C,H,W] + _, _, _, h, w = x_seq.shape + x = self._temporal_fusion(x_seq) # [B,C,H,W] + + x1 = self.patch_embed(x) # [B,C1,H/2,W/2] + x1 = self.stage1(x1) + + x2 = self.downsample(x1) # [B,C2,H/4,W/4] + x2 = self.stage2(x2) + + u1 = self.up1(x2) # [B,C1,H/2,W/2] + if u1.shape[-2:] != x1.shape[-2:]: + u1 = F.interpolate(u1, size=x1.shape[-2:], mode="bilinear", align_corners=False) + f1 = self.fuse1(torch.cat([u1, x1], dim=1)) + + u2 = self.up2(f1) # [B,C1,H,W] + logits = self.head(u2) + if logits.shape[-2:] != (h, w): + logits = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False) + return logits + + +def _choose_device(device_text: str) -> torch.device: + if device_text == "cuda" and torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader: + ds = TensorDataset( + torch.from_numpy(x.astype(np.float32)), + torch.from_numpy(y.astype(np.float32)), + ) + return DataLoader(ds, batch_size=batch_size, shuffle=shuffle) + + +def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray: + probs: List[np.ndarray] = [] + 
model.eval() + with torch.no_grad(): + for xb, _ in loader: + xb = xb.to(device) + logits = model(xb) + p = torch.sigmoid(logits).detach().cpu().numpy() + probs.append(p) + if not probs: + return np.zeros((0,), dtype=np.float32) + return np.concatenate(probs, axis=0).reshape(-1) + + +def train_swin_unet_track_o( + x_train: np.ndarray, + y_train: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + cfg: SwinUNetTrackOConfig, +): + if x_train.ndim != 5 or x_val.ndim != 5: + raise ValueError("x_train and x_val must be 5D arrays [N,T,C,H,W]") + if y_train.ndim != 4 or y_val.ndim != 4: + raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]") + + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + + device = _choose_device(cfg.device) + + model = TinySwinUNet( + in_channels=cfg.in_channels, + embed_dims=cfg.embed_dims, + num_heads=cfg.num_heads, + window_size=cfg.window_size, + mlp_ratio=cfg.mlp_ratio, + dropout=cfg.dropout, + ).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) + + total_px = float(y_train.size) + pos_px = float(np.sum(y_train)) + neg_px = max(1.0, total_px - pos_px) + raw_pos_weight = neg_px / max(pos_px, 1.0) + pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max)) + + criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device)) + + train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True) + val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False) + + history: List[Dict[str, float]] = [] + best_epoch = 1 + best_val_loss = float("inf") + best_state: Dict[str, torch.Tensor] | None = None + wait = 0 + + for epoch in range(1, cfg.max_epochs + 1): + model.train() + train_losses: List[float] = [] + + for xb, yb in train_loader: + xb = xb.to(device) + yb = yb.to(device) + + optimizer.zero_grad(set_to_none=True) + logits = model(xb) + loss = criterion(logits, yb) + loss.backward() + 
optimizer.step() + train_losses.append(float(loss.item())) + + model.eval() + val_losses: List[float] = [] + with torch.no_grad(): + for xb, yb in val_loader: + xb = xb.to(device) + yb = yb.to(device) + logits = model(xb) + loss = criterion(logits, yb) + val_losses.append(float(loss.item())) + + tr_loss = float(np.mean(train_losses)) if train_losses else float("nan") + va_loss = float(np.mean(val_losses)) if val_losses else float("nan") + + history.append( + { + "epoch": float(epoch), + "train_loss": tr_loss, + "val_loss": va_loss, + "learning_rate": float(optimizer.param_groups[0]["lr"]), + } + ) + + if va_loss < best_val_loss - cfg.min_delta: + best_val_loss = va_loss + best_epoch = epoch + best_state = deepcopy(model.state_dict()) + wait = 0 + else: + wait += 1 + + if wait >= cfg.early_stopping_rounds: + break + + if best_state is not None: + model.load_state_dict(best_state) + + val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7) + val_true = y_val.reshape(-1).astype(np.float32) + + mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0 + metrics = { + "auprc": float(average_precision_score(val_true, val_prob)), + "auroc": float(roc_auc_score(val_true, val_prob)), + "brier": float(brier_score_loss(val_true, val_prob)), + "nll": float(log_loss(val_true, val_prob)), + "ece": float(binary_ece(val_true, val_prob, n_bins=15)), + "mean_day_to_day_change": mean_change, + "normalized_consistency_score": normalized_consistency_score(mean_change), + } + + return model, history, metrics, best_epoch, pos_weight + + +def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + history_csv = output_dir / "history.csv" + with history_csv.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"]) + writer.writeheader() + 
def build_experiment_setting(
    cfg: SwinUNetTrackOConfig,
    best_epoch: int,
    pos_weight: float,
    metrics: Dict[str, float],
) -> Dict[str, Any]:
    """Assemble the experiment-setting record saved next to the run artifacts.

    Args:
        cfg: training configuration (serialized verbatim under ``params``).
        best_epoch: epoch selected by early stopping.
        pos_weight: effective BCE positive-class weight after clipping.
        metrics: validation metrics to embed under ``val_metrics``.

    Returns:
        JSON-serializable dict describing the benchmark run.
    """
    setting: Dict[str, Any] = {}
    setting["benchmark"] = {
        "task": "Track-O",
        "model_name": "swin_unet",
        "run_time": datetime.now().isoformat(),
    }
    setting["evaluation_protocol"] = {
        "discrimination": {"primary": "auprc", "secondary": "auroc"},
        "reliability": ["brier", "nll", "ece"],
        "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"],
    }
    setting["training"] = {
        "train_unit": "epoch",
        "max_epochs": cfg.max_epochs,
        "early_stopping_rounds": cfg.early_stopping_rounds,
        "best_epoch": best_epoch,
        "seed": cfg.seed,
        "batch_size": cfg.batch_size,
        "seq_len": cfg.seq_len,
    }
    setting["optimizer"] = {
        "name": "AdamW",
        "lr": cfg.lr,
        "weight_decay": cfg.weight_decay,
    }
    setting["learning_weight"] = {
        "type": "pixel_pos_weight",
        "value": pos_weight,
        "clip_max": cfg.pos_weight_clip_max,
    }
    setting["params"] = asdict(cfg)
    setting["val_metrics"] = metrics
    setting["note"] = "This module supports both real data and synthetic smoke demonstration."
    return setting
x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed) + + cfg = SwinUNetTrackOConfig( + seed=seed, + seq_len=seq_len, + max_epochs=max_epochs, + early_stopping_rounds=early_stopping_rounds, + device="cpu", + ) + + model, history, val_metrics, best_epoch, pos_weight = train_swin_unet_track_o( + x_train, + y_train, + x_val, + y_val, + cfg, + ) + + test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False) + test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 1e-7) + test_true = y_test.reshape(-1).astype(np.float32) + + test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0 + test_metrics = { + "auprc": float(average_precision_score(test_true, test_prob)), + "auroc": float(roc_auc_score(test_true, test_prob)), + "brier": float(brier_score_loss(test_true, test_prob)), + "nll": float(log_loss(test_true, test_prob)), + "ece": float(binary_ece(test_true, test_prob, n_bins=15)), + "mean_day_to_day_change": test_mean_change, + "normalized_consistency_score": normalized_consistency_score(test_mean_change), + } + + output_dir.mkdir(parents=True, exist_ok=True) + save_history_and_plot(history, output_dir) + + torch.save( + { + "state_dict": model.state_dict(), + "config": asdict(cfg), + "best_epoch": best_epoch, + }, + output_dir / "swin_unet_model.pt", + ) + + setting = build_experiment_setting(cfg, best_epoch=best_epoch, pos_weight=pos_weight, metrics=val_metrics) + setting["data"] = { + "n_samples": n_samples, + "image_size": image_size, + "seq_len": seq_len, + "split": {"train": int(x_train.shape[0]), "val": int(x_val.shape[0]), "test": int(x_test.shape[0])}, + } + setting["test_metrics"] = test_metrics + + (output_dir / "experiment_setting.json").write_text(json.dumps(setting, indent=2), encoding="utf-8") + (output_dir / "metrics.json").write_text( + json.dumps({"val": val_metrics, "test": test_metrics, "best_epoch": best_epoch}, 
indent=2), + encoding="utf-8", + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run Swin-UNet Track-O synthetic smoke demo") + parser.add_argument("--output_dir", default=None) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--n_samples", type=int, default=160) + parser.add_argument("--seq_len", type=int, default=6) + parser.add_argument("--image_size", type=int, default=24) + parser.add_argument("--max_epochs", type=int, default=60) + parser.add_argument("--early_stopping_rounds", type=int, default=12) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + base = Path(__file__).resolve().parents[1] / "runs_scaffold" + out = ( + Path(args.output_dir) + if args.output_dir + else base / f"swin_unet_synthetic_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + ) + + run_synthetic_demo( + output_dir=out, + seed=args.seed, + n_samples=args.n_samples, + seq_len=args.seq_len, + image_size=args.image_size, + max_epochs=args.max_epochs, + early_stopping_rounds=args.early_stopping_rounds, + ) + print(f"[done] swin_unet synthetic demo saved to: {out}") + + +if __name__ == "__main__": + main() + +from ._wildfire_benchmark_utils import SegmentationPort, filter_init_kwargs, require_task + + +def swin_unet_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module: + require_task(task, {"segmentation"}, "swin_unet") + init_kwargs = filter_init_kwargs(TinySwinUNet, {"in_channels": int(in_channels), **kwargs}) + model = TinySwinUNet(**init_kwargs) + return SegmentationPort(model=model, out_channels=int(out_dim)) + + +__all__ = ["TinySwinUNet", "swin_unet_builder"] diff --git a/pyhazards/models/swinlstm.py b/pyhazards/models/swinlstm.py new file mode 100644 index 00000000..491e3cb5 --- /dev/null +++ b/pyhazards/models/swinlstm.py @@ -0,0 +1,506 @@ +from __future__ import annotations + +import argparse +import csv +import json +import sys +from copy import 
deepcopy +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import torch +from sklearn.metrics import average_precision_score, brier_score_loss, log_loss, roc_auc_score +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + from pyhazards.models.mau import ( + binary_ece, + make_synthetic_fire_sequences, + normalized_consistency_score, + split_train_val_test, + ) +else: + from .mau import ( + binary_ece, + make_synthetic_fire_sequences, + normalized_consistency_score, + split_train_val_test, + ) + + +@dataclass +class SwinLSTMTrackOConfig: + seq_len: int = 6 + in_channels: int = 1 + embed_dim: int = 16 + hidden_channels: int = 16 + num_heads: int = 4 + window_size: int = 3 + lr: float = 1e-3 + weight_decay: float = 1e-4 + batch_size: int = 8 + max_epochs: int = 120 + early_stopping_rounds: int = 16 + min_delta: float = 1e-4 + seed: int = 42 + pos_weight_clip_max: float = 50.0 + device: str = "cpu" + + +def _window_partition(x: torch.Tensor, window_size: int) -> torch.Tensor: + # x: [B, H, W, C] + b, h, w, c = x.shape + x = x.view(b, h // window_size, window_size, w // window_size, window_size, c) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = x.view(-1, window_size * window_size, c) + return windows + + +def _window_reverse(windows: torch.Tensor, window_size: int, h: int, w: int, b: int) -> torch.Tensor: + # windows: [B*num_windows, window_size*window_size, C] + c = windows.shape[-1] + x = windows.view(b, h // window_size, w // window_size, window_size, window_size, c) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.view(b, h, w, c) + return x + + +class WindowAttentionBlock(nn.Module): + def __init__(self, dim: int, 
class ConvLSTMCell(nn.Module):
    """One ConvLSTM step: a single fused 3x3 conv emits the i/f/o/g gates.

    The conv consumes the channel-concatenated input and previous hidden
    state, so its input width is ``input_channels + hidden_channels`` and
    its output width is ``4 * hidden_channels``.
    """

    def __init__(self, input_channels: int, hidden_channels: int):
        super().__init__()
        self.hidden_channels = hidden_channels
        # Attribute name "conv" is load-bearing for state_dict compatibility.
        self.conv = nn.Conv2d(input_channels + hidden_channels, hidden_channels * 4, kernel_size=3, padding=1)

    def forward(self, x_t: torch.Tensor, h_prev: torch.Tensor, c_prev: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Advance the cell one step; returns (h_next, c_next)."""
        stacked = torch.cat([x_t, h_prev], dim=1)
        gate_i, gate_f, gate_o, gate_g = torch.chunk(self.conv(stacked), 4, dim=1)

        input_gate = torch.sigmoid(gate_i)
        forget_gate = torch.sigmoid(gate_f)
        output_gate = torch.sigmoid(gate_o)
        candidate = torch.tanh(gate_g)

        c_next = forget_gate * c_prev + input_gate * candidate
        h_next = output_gate * torch.tanh(c_next)
        return h_next, c_next
self.embed_dim = embed_dim + self.hidden_channels = hidden_channels + + self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=3, stride=2, padding=1) + self.block1 = WindowAttentionBlock(dim=embed_dim, num_heads=num_heads, window_size=window_size, shift_size=0) + self.block2 = WindowAttentionBlock( + dim=embed_dim, + num_heads=num_heads, + window_size=window_size, + shift_size=max(1, window_size // 2), + ) + + self.rnn = ConvLSTMCell(input_channels=embed_dim, hidden_channels=hidden_channels) + + self.decoder = nn.Sequential( + nn.ConvTranspose2d(hidden_channels, hidden_channels, kernel_size=2, stride=2), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_channels, 1, kernel_size=1), + ) + + def forward(self, x_seq: torch.Tensor) -> torch.Tensor: + # x_seq: [B,T,C,H,W] + b, _, _, h, w = x_seq.shape + h2, w2 = (h + 1) // 2, (w + 1) // 2 + device = x_seq.device + + h_state = torch.zeros((b, self.hidden_channels, h2, w2), device=device) + c_state = torch.zeros((b, self.hidden_channels, h2, w2), device=device) + + for t in range(x_seq.shape[1]): + x_t = self.patch_embed(x_seq[:, t]) + x_t = self.block1(x_t) + x_t = self.block2(x_t) + h_state, c_state = self.rnn(x_t, h_state, c_state) + + return self.decoder(h_state) + + +def _choose_device(device_text: str) -> torch.device: + if device_text == "cuda" and torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader: + ds = TensorDataset( + torch.from_numpy(x.astype(np.float32)), + torch.from_numpy(y.astype(np.float32)), + ) + return DataLoader(ds, batch_size=batch_size, shuffle=shuffle) + + +def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray: + probs: List[np.ndarray] = [] + model.eval() + with torch.no_grad(): + for xb, _ in loader: + xb = xb.to(device) + logits = model(xb) + p = torch.sigmoid(logits).detach().cpu().numpy() + probs.append(p) + 
if not probs: + return np.zeros((0,), dtype=np.float32) + return np.concatenate(probs, axis=0).reshape(-1) + + +def train_swinlstm_track_o( + x_train: np.ndarray, + y_train: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + cfg: SwinLSTMTrackOConfig, +): + if x_train.ndim != 5 or x_val.ndim != 5: + raise ValueError("x_train and x_val must be 5D arrays [N,T,C,H,W]") + if y_train.ndim != 4 or y_val.ndim != 4: + raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]") + + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + + device = _choose_device(cfg.device) + + model = TinySwinLSTM( + in_channels=cfg.in_channels, + embed_dim=cfg.embed_dim, + hidden_channels=cfg.hidden_channels, + num_heads=cfg.num_heads, + window_size=cfg.window_size, + ).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) + + total_px = float(y_train.size) + pos_px = float(np.sum(y_train)) + neg_px = max(1.0, total_px - pos_px) + raw_pos_weight = neg_px / max(pos_px, 1.0) + pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max)) + + criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device)) + + train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True) + val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False) + + history: List[Dict[str, float]] = [] + best_epoch = 1 + best_val_loss = float("inf") + best_state: Dict[str, torch.Tensor] | None = None + wait = 0 + + for epoch in range(1, cfg.max_epochs + 1): + model.train() + train_losses: List[float] = [] + + for xb, yb in train_loader: + xb = xb.to(device) + yb = yb.to(device) + + optimizer.zero_grad(set_to_none=True) + logits = model(xb) + loss = criterion(logits, yb) + loss.backward() + optimizer.step() + train_losses.append(float(loss.item())) + + model.eval() + val_losses: List[float] = [] + with torch.no_grad(): + for xb, yb in val_loader: + xb = xb.to(device) + yb = 
yb.to(device) + logits = model(xb) + loss = criterion(logits, yb) + val_losses.append(float(loss.item())) + + tr_loss = float(np.mean(train_losses)) if train_losses else float("nan") + va_loss = float(np.mean(val_losses)) if val_losses else float("nan") + + history.append( + { + "epoch": float(epoch), + "train_loss": tr_loss, + "val_loss": va_loss, + "learning_rate": float(optimizer.param_groups[0]["lr"]), + } + ) + + if va_loss < best_val_loss - cfg.min_delta: + best_val_loss = va_loss + best_epoch = epoch + best_state = deepcopy(model.state_dict()) + wait = 0 + else: + wait += 1 + + if wait >= cfg.early_stopping_rounds: + break + + if best_state is not None: + model.load_state_dict(best_state) + + val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7) + val_true = y_val.reshape(-1).astype(np.float32) + + mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0 + metrics = { + "auprc": float(average_precision_score(val_true, val_prob)), + "auroc": float(roc_auc_score(val_true, val_prob)), + "brier": float(brier_score_loss(val_true, val_prob)), + "nll": float(log_loss(val_true, val_prob)), + "ece": float(binary_ece(val_true, val_prob, n_bins=15)), + "mean_day_to_day_change": mean_change, + "normalized_consistency_score": normalized_consistency_score(mean_change), + } + + return model, history, metrics, best_epoch, pos_weight + + +def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + history_csv = output_dir / "history.csv" + with history_csv.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"]) + writer.writeheader() + writer.writerows(history) + + x = [int(r["epoch"]) for r in history] + y_tr = [float(r["train_loss"]) for r in history] + y_va = [float(r["val_loss"]) for r in history] + + plt.figure(figsize=(8, 5)) + 
plt.plot(x, y_tr, label="train_bce", marker="o", linewidth=1.4) + plt.plot(x, y_va, label="val_bce", marker="s", linewidth=1.2) + plt.xlabel("epoch") + plt.ylabel("loss") + plt.title("SwinLSTM Track-O: train loss vs epoch") + plt.grid(alpha=0.3) + plt.legend() + plt.tight_layout() + plt.savefig(output_dir / "loss_curve.png", dpi=150) + plt.close() + + +def build_experiment_setting( + cfg: SwinLSTMTrackOConfig, + best_epoch: int, + pos_weight: float, + metrics: Dict[str, float], +) -> Dict[str, Any]: + return { + "benchmark": { + "task": "Track-O", + "model_name": "swinlstm", + "run_time": datetime.now().isoformat(), + }, + "evaluation_protocol": { + "discrimination": {"primary": "auprc", "secondary": "auroc"}, + "reliability": ["brier", "nll", "ece"], + "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"], + }, + "training": { + "train_unit": "epoch", + "max_epochs": cfg.max_epochs, + "early_stopping_rounds": cfg.early_stopping_rounds, + "best_epoch": best_epoch, + "seed": cfg.seed, + "batch_size": cfg.batch_size, + "seq_len": cfg.seq_len, + }, + "optimizer": { + "name": "AdamW", + "lr": cfg.lr, + "weight_decay": cfg.weight_decay, + }, + "learning_weight": { + "type": "pixel_pos_weight", + "value": pos_weight, + "clip_max": cfg.pos_weight_clip_max, + }, + "params": asdict(cfg), + "val_metrics": metrics, + "note": "This module supports both real data and synthetic smoke demonstration.", + } + + +def run_synthetic_demo( + output_dir: Path, + seed: int = 42, + n_samples: int = 160, + seq_len: int = 6, + image_size: int = 24, + max_epochs: int = 60, + early_stopping_rounds: int = 12, +) -> None: + x, y = make_synthetic_fire_sequences( + n_samples=n_samples, + seq_len=seq_len, + image_size=image_size, + seed=seed, + ) + x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed) + + cfg = SwinLSTMTrackOConfig( + seed=seed, + seq_len=seq_len, + max_epochs=max_epochs, + early_stopping_rounds=early_stopping_rounds, + 
device="cpu", + ) + + model, history, val_metrics, best_epoch, pos_weight = train_swinlstm_track_o( + x_train, + y_train, + x_val, + y_val, + cfg, + ) + + test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False) + test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 1e-7) + test_true = y_test.reshape(-1).astype(np.float32) + + test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0 + test_metrics = { + "auprc": float(average_precision_score(test_true, test_prob)), + "auroc": float(roc_auc_score(test_true, test_prob)), + "brier": float(brier_score_loss(test_true, test_prob)), + "nll": float(log_loss(test_true, test_prob)), + "ece": float(binary_ece(test_true, test_prob, n_bins=15)), + "mean_day_to_day_change": test_mean_change, + "normalized_consistency_score": normalized_consistency_score(test_mean_change), + } + + output_dir.mkdir(parents=True, exist_ok=True) + save_history_and_plot(history, output_dir) + + torch.save( + { + "state_dict": model.state_dict(), + "config": asdict(cfg), + "best_epoch": best_epoch, + }, + output_dir / "swinlstm_model.pt", + ) + + setting = build_experiment_setting(cfg, best_epoch=best_epoch, pos_weight=pos_weight, metrics=val_metrics) + setting["data"] = { + "n_samples": n_samples, + "image_size": image_size, + "seq_len": seq_len, + "split": {"train": int(x_train.shape[0]), "val": int(x_val.shape[0]), "test": int(x_test.shape[0])}, + } + setting["test_metrics"] = test_metrics + + (output_dir / "experiment_setting.json").write_text(json.dumps(setting, indent=2), encoding="utf-8") + (output_dir / "metrics.json").write_text( + json.dumps({"val": val_metrics, "test": test_metrics, "best_epoch": best_epoch}, indent=2), + encoding="utf-8", + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run SwinLSTM Track-O synthetic smoke demo") + parser.add_argument("--output_dir", 
default=None) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--n_samples", type=int, default=160) + parser.add_argument("--seq_len", type=int, default=6) + parser.add_argument("--image_size", type=int, default=24) + parser.add_argument("--max_epochs", type=int, default=60) + parser.add_argument("--early_stopping_rounds", type=int, default=12) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + base = Path(__file__).resolve().parents[1] / "runs_scaffold" + out = Path(args.output_dir) if args.output_dir else base / f"swinlstm_synthetic_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + run_synthetic_demo( + output_dir=out, + seed=args.seed, + n_samples=args.n_samples, + seq_len=args.seq_len, + image_size=args.image_size, + max_epochs=args.max_epochs, + early_stopping_rounds=args.early_stopping_rounds, + ) + print(f"[done] swinlstm synthetic demo saved to: {out}") + + +if __name__ == "__main__": + main() + +from ._wildfire_benchmark_utils import SegmentationPort, filter_init_kwargs, require_task + + +def swinlstm_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module: + require_task(task, {"segmentation"}, "swinlstm") + init_kwargs = filter_init_kwargs(TinySwinLSTM, {"in_channels": int(in_channels), **kwargs}) + model = TinySwinLSTM(**init_kwargs) + return SegmentationPort(model=model, out_channels=int(out_dim)) + + +__all__ = ["TinySwinLSTM", "swinlstm_builder"] diff --git a/pyhazards/models/tcn.py b/pyhazards/models/tcn.py new file mode 100644 index 00000000..f56b8654 --- /dev/null +++ b/pyhazards/models/tcn.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import argparse +import csv +import json +import sys +from copy import deepcopy +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import 
torch +from sklearn.metrics import average_precision_score, brier_score_loss, log_loss, roc_auc_score +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + from pyhazards.models.mau import ( + binary_ece, + make_synthetic_fire_sequences, + normalized_consistency_score, + split_train_val_test, + ) +else: + from .mau import ( + binary_ece, + make_synthetic_fire_sequences, + normalized_consistency_score, + split_train_val_test, + ) + + +@dataclass +class TCNTrackOConfig: + seq_len: int = 6 + in_channels: int = 1 + embed_dim: int = 16 + hidden_channels: int = 16 + kernel_size: int = 3 + num_levels: int = 3 + dropout: float = 0.1 + lr: float = 1e-3 + weight_decay: float = 1e-4 + batch_size: int = 8 + max_epochs: int = 120 + early_stopping_rounds: int = 16 + min_delta: float = 1e-4 + seed: int = 42 + pos_weight_clip_max: float = 50.0 + device: str = "cpu" + + +class Chomp1d(nn.Module): + def __init__(self, chomp_size: int): + super().__init__() + self.chomp_size = chomp_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.chomp_size <= 0: + return x + return x[:, :, :-self.chomp_size] + + +class TemporalBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, dilation: int, dropout: float): + super().__init__() + padding = (kernel_size - 1) * dilation + + self.net = nn.Sequential( + nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation), + Chomp1d(padding), + nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Conv1d(out_channels, out_channels, kernel_size, padding=padding, dilation=dilation), + Chomp1d(padding), + nn.ReLU(inplace=True), + nn.Dropout(dropout), + ) + self.downsample = nn.Conv1d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None + self.act = nn.ReLU(inplace=True) + + def forward(self, x: torch.Tensor) -> 
torch.Tensor: + out = self.net(x) + residual = x if self.downsample is None else self.downsample(x) + return self.act(out + residual) + + +class TemporalConvNet(nn.Module): + def __init__(self, in_channels: int, hidden_channels: int, kernel_size: int, num_levels: int, dropout: float): + super().__init__() + layers: List[nn.Module] = [] + for i in range(num_levels): + dilation = 2 ** i + cin = in_channels if i == 0 else hidden_channels + cout = hidden_channels + layers.append(TemporalBlock(cin, cout, kernel_size=kernel_size, dilation=dilation, dropout=dropout)) + self.net = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class TinyTCN(nn.Module): + def __init__( + self, + in_channels: int = 1, + embed_dim: int = 16, + hidden_channels: int = 16, + kernel_size: int = 3, + num_levels: int = 3, + dropout: float = 0.1, + ): + super().__init__() + self.embed_dim = embed_dim + self.hidden_channels = hidden_channels + + self.frame_encoder = nn.Sequential( + nn.Conv2d(in_channels, embed_dim, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ) + + self.temporal = TemporalConvNet( + in_channels=embed_dim, + hidden_channels=hidden_channels, + kernel_size=kernel_size, + num_levels=num_levels, + dropout=dropout, + ) + + self.decoder = nn.Sequential( + nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_channels, 1, kernel_size=1), + ) + + def forward(self, x_seq: torch.Tensor) -> torch.Tensor: + # x_seq: [B,T,C,H,W] + b, t, c, h, w = x_seq.shape + + x = x_seq.reshape(b * t, c, h, w) + x = self.frame_encoder(x) + d = x.shape[1] + + x = x.reshape(b, t, d, h, w) + x = x.permute(0, 3, 4, 2, 1).contiguous() # [B,H,W,D,T] + x = x.reshape(b * h * w, d, t) # [BHW,D,T] + + x = self.temporal(x) + x_last = x[:, :, -1] # [BHW,HID] + + x_last = x_last.reshape(b, h, w, 
self.hidden_channels).permute(0, 3, 1, 2).contiguous() + return self.decoder(x_last) + + +def _choose_device(device_text: str) -> torch.device: + if device_text == "cuda" and torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader: + ds = TensorDataset( + torch.from_numpy(x.astype(np.float32)), + torch.from_numpy(y.astype(np.float32)), + ) + return DataLoader(ds, batch_size=batch_size, shuffle=shuffle) + + +def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray: + probs: List[np.ndarray] = [] + model.eval() + with torch.no_grad(): + for xb, _ in loader: + xb = xb.to(device) + logits = model(xb) + p = torch.sigmoid(logits).detach().cpu().numpy() + probs.append(p) + if not probs: + return np.zeros((0,), dtype=np.float32) + return np.concatenate(probs, axis=0).reshape(-1) + + +def train_tcn_track_o( + x_train: np.ndarray, + y_train: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + cfg: TCNTrackOConfig, +): + if x_train.ndim != 5 or x_val.ndim != 5: + raise ValueError("x_train and x_val must be 5D arrays [N,T,C,H,W]") + if y_train.ndim != 4 or y_val.ndim != 4: + raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]") + + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + + device = _choose_device(cfg.device) + + model = TinyTCN( + in_channels=cfg.in_channels, + embed_dim=cfg.embed_dim, + hidden_channels=cfg.hidden_channels, + kernel_size=cfg.kernel_size, + num_levels=cfg.num_levels, + dropout=cfg.dropout, + ).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) + + total_px = float(y_train.size) + pos_px = float(np.sum(y_train)) + neg_px = max(1.0, total_px - pos_px) + raw_pos_weight = neg_px / max(pos_px, 1.0) + pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max)) + + criterion = 
nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device)) + + train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True) + val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False) + + history: List[Dict[str, float]] = [] + best_epoch = 1 + best_val_loss = float("inf") + best_state: Dict[str, torch.Tensor] | None = None + wait = 0 + + for epoch in range(1, cfg.max_epochs + 1): + model.train() + train_losses: List[float] = [] + + for xb, yb in train_loader: + xb = xb.to(device) + yb = yb.to(device) + + optimizer.zero_grad(set_to_none=True) + logits = model(xb) + loss = criterion(logits, yb) + loss.backward() + optimizer.step() + train_losses.append(float(loss.item())) + + model.eval() + val_losses: List[float] = [] + with torch.no_grad(): + for xb, yb in val_loader: + xb = xb.to(device) + yb = yb.to(device) + logits = model(xb) + loss = criterion(logits, yb) + val_losses.append(float(loss.item())) + + tr_loss = float(np.mean(train_losses)) if train_losses else float("nan") + va_loss = float(np.mean(val_losses)) if val_losses else float("nan") + + history.append( + { + "epoch": float(epoch), + "train_loss": tr_loss, + "val_loss": va_loss, + "learning_rate": float(optimizer.param_groups[0]["lr"]), + } + ) + + if va_loss < best_val_loss - cfg.min_delta: + best_val_loss = va_loss + best_epoch = epoch + best_state = deepcopy(model.state_dict()) + wait = 0 + else: + wait += 1 + + if wait >= cfg.early_stopping_rounds: + break + + if best_state is not None: + model.load_state_dict(best_state) + + val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7) + val_true = y_val.reshape(-1).astype(np.float32) + + mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0 + metrics = { + "auprc": float(average_precision_score(val_true, val_prob)), + "auroc": float(roc_auc_score(val_true, val_prob)), + "brier": 
float(brier_score_loss(val_true, val_prob)), + "nll": float(log_loss(val_true, val_prob)), + "ece": float(binary_ece(val_true, val_prob, n_bins=15)), + "mean_day_to_day_change": mean_change, + "normalized_consistency_score": normalized_consistency_score(mean_change), + } + + return model, history, metrics, best_epoch, pos_weight + + +def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + history_csv = output_dir / "history.csv" + with history_csv.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"]) + writer.writeheader() + writer.writerows(history) + + x = [int(r["epoch"]) for r in history] + y_tr = [float(r["train_loss"]) for r in history] + y_va = [float(r["val_loss"]) for r in history] + + plt.figure(figsize=(8, 5)) + plt.plot(x, y_tr, label="train_bce", marker="o", linewidth=1.4) + plt.plot(x, y_va, label="val_bce", marker="s", linewidth=1.2) + plt.xlabel("epoch") + plt.ylabel("loss") + plt.title("TCN Track-O: train loss vs epoch") + plt.grid(alpha=0.3) + plt.legend() + plt.tight_layout() + plt.savefig(output_dir / "loss_curve.png", dpi=150) + plt.close() + + +def build_experiment_setting( + cfg: TCNTrackOConfig, + best_epoch: int, + pos_weight: float, + metrics: Dict[str, float], +) -> Dict[str, Any]: + return { + "benchmark": { + "task": "Track-O", + "model_name": "tcn", + "run_time": datetime.now().isoformat(), + }, + "evaluation_protocol": { + "discrimination": {"primary": "auprc", "secondary": "auroc"}, + "reliability": ["brier", "nll", "ece"], + "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"], + }, + "training": { + "train_unit": "epoch", + "max_epochs": cfg.max_epochs, + "early_stopping_rounds": cfg.early_stopping_rounds, + "best_epoch": best_epoch, + "seed": cfg.seed, + "batch_size": cfg.batch_size, + "seq_len": cfg.seq_len, + }, + "optimizer": { + 
"name": "AdamW", + "lr": cfg.lr, + "weight_decay": cfg.weight_decay, + }, + "learning_weight": { + "type": "pixel_pos_weight", + "value": pos_weight, + "clip_max": cfg.pos_weight_clip_max, + }, + "params": asdict(cfg), + "val_metrics": metrics, + "note": "This module supports both real data and synthetic smoke demonstration.", + } + + +def run_synthetic_demo( + output_dir: Path, + seed: int = 42, + n_samples: int = 160, + seq_len: int = 6, + image_size: int = 24, + max_epochs: int = 60, + early_stopping_rounds: int = 12, +) -> None: + x, y = make_synthetic_fire_sequences( + n_samples=n_samples, + seq_len=seq_len, + image_size=image_size, + seed=seed, + ) + x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed) + + cfg = TCNTrackOConfig( + seed=seed, + seq_len=seq_len, + max_epochs=max_epochs, + early_stopping_rounds=early_stopping_rounds, + device="cpu", + ) + + model, history, val_metrics, best_epoch, pos_weight = train_tcn_track_o( + x_train, + y_train, + x_val, + y_val, + cfg, + ) + + test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False) + test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 1e-7) + test_true = y_test.reshape(-1).astype(np.float32) + + test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0 + test_metrics = { + "auprc": float(average_precision_score(test_true, test_prob)), + "auroc": float(roc_auc_score(test_true, test_prob)), + "brier": float(brier_score_loss(test_true, test_prob)), + "nll": float(log_loss(test_true, test_prob)), + "ece": float(binary_ece(test_true, test_prob, n_bins=15)), + "mean_day_to_day_change": test_mean_change, + "normalized_consistency_score": normalized_consistency_score(test_mean_change), + } + + output_dir.mkdir(parents=True, exist_ok=True) + save_history_and_plot(history, output_dir) + + torch.save( + { + "state_dict": model.state_dict(), + "config": 
asdict(cfg), + "best_epoch": best_epoch, + }, + output_dir / "tcn_model.pt", + ) + + setting = build_experiment_setting(cfg, best_epoch=best_epoch, pos_weight=pos_weight, metrics=val_metrics) + setting["data"] = { + "n_samples": n_samples, + "image_size": image_size, + "seq_len": seq_len, + "split": {"train": int(x_train.shape[0]), "val": int(x_val.shape[0]), "test": int(x_test.shape[0])}, + } + setting["test_metrics"] = test_metrics + + (output_dir / "experiment_setting.json").write_text(json.dumps(setting, indent=2), encoding="utf-8") + (output_dir / "metrics.json").write_text( + json.dumps({"val": val_metrics, "test": test_metrics, "best_epoch": best_epoch}, indent=2), + encoding="utf-8", + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run TCN Track-O synthetic smoke demo") + parser.add_argument("--output_dir", default=None) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--n_samples", type=int, default=160) + parser.add_argument("--seq_len", type=int, default=6) + parser.add_argument("--image_size", type=int, default=24) + parser.add_argument("--max_epochs", type=int, default=60) + parser.add_argument("--early_stopping_rounds", type=int, default=12) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + base = Path(__file__).resolve().parents[1] / "runs_scaffold" + out = Path(args.output_dir) if args.output_dir else base / f"tcn_synthetic_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + run_synthetic_demo( + output_dir=out, + seed=args.seed, + n_samples=args.n_samples, + seq_len=args.seq_len, + image_size=args.image_size, + max_epochs=args.max_epochs, + early_stopping_rounds=args.early_stopping_rounds, + ) + print(f"[done] tcn synthetic demo saved to: {out}") + + +if __name__ == "__main__": + main() + +from ._wildfire_benchmark_utils import SegmentationPort, filter_init_kwargs, require_task + + +def tcn_builder(task: str, in_channels: int = 1, out_dim: int = 1, 
**kwargs: Any) -> nn.Module: + require_task(task, {"segmentation"}, "tcn") + init_kwargs = filter_init_kwargs(TinyTCN, {"in_channels": int(in_channels), **kwargs}) + model = TinyTCN(**init_kwargs) + return SegmentationPort(model=model, out_channels=int(out_dim)) + + +__all__ = ["TinyTCN", "tcn_builder"] diff --git a/pyhazards/models/ts_satfire.py b/pyhazards/models/ts_satfire.py new file mode 100644 index 00000000..9e9c28f6 --- /dev/null +++ b/pyhazards/models/ts_satfire.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import torch +import torch.nn as nn + + +class TSSatFire(nn.Module): + """Spatio-temporal wildfire prediction model inspired by TS-SatFire.""" + + def __init__( + self, + history: int = 5, + in_channels: int = 8, + hidden_dim: int = 32, + out_channels: int = 1, + dropout: float = 0.1, + ): + super().__init__() + if history <= 0: + raise ValueError(f"history must be positive, got {history}") + if in_channels <= 0: + raise ValueError(f"in_channels must be positive, got {in_channels}") + if hidden_dim <= 0: + raise ValueError(f"hidden_dim must be positive, got {hidden_dim}") + if out_channels <= 0: + raise ValueError(f"out_channels must be positive, got {out_channels}") + if not 0.0 <= dropout < 1.0: + raise ValueError(f"dropout must be in [0, 1), got {dropout}") + + self.history = int(history) + self.in_channels = int(in_channels) + self.temporal_encoder = nn.Sequential( + nn.Conv3d(in_channels, hidden_dim, kernel_size=(3, 3, 3), padding=1), + nn.GELU(), + nn.Conv3d(hidden_dim, hidden_dim, kernel_size=(3, 3, 3), padding=1), + nn.GELU(), + ) + self.time_attention = nn.Conv3d(hidden_dim, 1, kernel_size=1) + self.decoder = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), + nn.GELU(), + nn.Dropout2d(dropout) if dropout > 0 else nn.Identity(), + nn.Conv2d(hidden_dim, out_channels, kernel_size=1), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.ndim != 5: + raise ValueError( + "TSSatFire expects input 
shape (batch, history, channels, height, width), " + f"got {tuple(x.shape)}." + ) + if x.size(1) != self.history: + raise ValueError(f"TSSatFire expected history={self.history}, got {x.size(1)}.") + if x.size(2) != self.in_channels: + raise ValueError(f"TSSatFire expected in_channels={self.in_channels}, got {x.size(2)}.") + + feat = self.temporal_encoder(x.permute(0, 2, 1, 3, 4)) + attn = torch.softmax(self.time_attention(feat), dim=2) + pooled = torch.sum(attn * feat, dim=2) + return self.decoder(pooled) + + +def ts_satfire_builder( + task: str, + history: int = 5, + in_channels: int = 8, + hidden_dim: int = 32, + out_channels: int = 1, + dropout: float = 0.1, + **kwargs, +) -> nn.Module: + _ = kwargs + if task.lower() not in {"segmentation", "regression"}: + raise ValueError(f"ts_satfire supports task='segmentation' or 'regression', got {task!r}.") + return TSSatFire( + history=history, + in_channels=in_channels, + hidden_dim=hidden_dim, + out_channels=out_channels, + dropout=dropout, + ) + + +__all__ = ["TSSatFire", "ts_satfire_builder"] diff --git a/pyhazards/models/unet.py b/pyhazards/models/unet.py new file mode 100644 index 00000000..8eb11a42 --- /dev/null +++ b/pyhazards/models/unet.py @@ -0,0 +1,482 @@ +from __future__ import annotations + +import argparse +import csv +import json +from copy import deepcopy +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import torch +from sklearn.metrics import average_precision_score, brier_score_loss, log_loss, roc_auc_score +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + + +@dataclass +class UNetTrackOConfig: + in_channels: int = 1 + base_channels: int = 8 + lr: float = 1e-3 + weight_decay: float = 1e-4 + batch_size: int = 8 + max_epochs: int = 120 + early_stopping_rounds: int = 16 + min_delta: 
float = 1e-4 + seed: int = 42 + pos_weight_clip_max: float = 50.0 + device: str = "cpu" + + +class ConvBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.block(x) + + +class TinyUNet(nn.Module): + def __init__(self, in_channels: int = 1, base_channels: int = 16): + super().__init__() + c1, c2, c3 = base_channels, base_channels * 2, base_channels * 4 + + self.enc1 = ConvBlock(in_channels, c1) + self.pool1 = nn.MaxPool2d(kernel_size=2) + + self.enc2 = ConvBlock(c1, c2) + self.pool2 = nn.MaxPool2d(kernel_size=2) + + self.bottleneck = ConvBlock(c2, c3) + + self.up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2) + self.dec2 = ConvBlock(c2 + c2, c2) + + self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2) + self.dec1 = ConvBlock(c1 + c1, c1) + + self.head = nn.Conv2d(c1, 1, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1 = self.enc1(x) + x2 = self.enc2(self.pool1(x1)) + xb = self.bottleneck(self.pool2(x2)) + + y2 = self.up2(xb) + y2 = torch.cat([y2, x2], dim=1) + y2 = self.dec2(y2) + + y1 = self.up1(y2) + y1 = torch.cat([y1, x1], dim=1) + y1 = self.dec1(y1) + + return self.head(y1) + + +def binary_ece(y_true: np.ndarray, y_prob: np.ndarray, n_bins: int = 15) -> float: + bins = np.linspace(0.0, 1.0, n_bins + 1) + ece = 0.0 + n = float(len(y_true)) + for i in range(n_bins): + lo, hi = bins[i], bins[i + 1] + if i == n_bins - 1: + mask = (y_prob >= lo) & (y_prob <= hi) + else: + mask = (y_prob >= lo) & (y_prob < hi) + if not np.any(mask): + continue + acc = float(np.mean(y_true[mask])) + conf = float(np.mean(y_prob[mask])) + ece += (float(np.sum(mask)) / n) * abs(acc - conf) + return float(ece) + + +def 
normalized_consistency_score(mean_day_to_day_change: float) -> float: + return float(np.clip(1.0 - float(mean_day_to_day_change), 0.0, 1.0)) + + +def make_synthetic_fire_maps( + n_samples: int, + image_size: int, + seed: int, +) -> Tuple[np.ndarray, np.ndarray]: + rng = np.random.default_rng(seed) + yy, xx = np.meshgrid(np.arange(image_size), np.arange(image_size), indexing="ij") + + x = np.zeros((n_samples, 1, image_size, image_size), dtype=np.float32) + y = np.zeros((n_samples, 1, image_size, image_size), dtype=np.float32) + + for i in range(n_samples): + field = rng.normal(0.0, 0.15, size=(image_size, image_size)) + n_sources = int(rng.integers(1, 4)) + + for _ in range(n_sources): + cx = float(rng.uniform(0, image_size - 1)) + cy = float(rng.uniform(0, image_size - 1)) + sigma = float(rng.uniform(1.8, 4.8)) + amp = float(rng.uniform(0.8, 2.2)) + dist2 = (xx - cx) ** 2 + (yy - cy) ** 2 + field += amp * np.exp(-dist2 / (2.0 * sigma * sigma)) + + terrain = (yy / max(1, image_size - 1)) * rng.uniform(-0.15, 0.15) + wind = (xx / max(1, image_size - 1)) * rng.uniform(-0.25, 0.25) + + signal = field + terrain + wind + rng.normal(0.0, 0.08, size=(image_size, image_size)) + threshold = float(np.quantile(field, 0.90)) + mask = (field > threshold).astype(np.float32) + + x[i, 0] = signal.astype(np.float32) + y[i, 0] = mask + + x_mean = float(np.mean(x)) + x_std = float(np.std(x) + 1e-6) + x = (x - x_mean) / x_std + return x, y + + +def split_train_val_test( + x: np.ndarray, + y: np.ndarray, + seed: int, + train_ratio: float = 0.7, + val_ratio: float = 0.15, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + n = x.shape[0] + rng = np.random.default_rng(seed) + idx = rng.permutation(n) + + n_train = max(1, int(n * train_ratio)) + n_val = max(1, int(n * val_ratio)) + n_train = min(n_train, n - 2) + n_val = min(n_val, n - n_train - 1) + + train_idx = idx[:n_train] + val_idx = idx[n_train : n_train + n_val] + test_idx = idx[n_train + n_val :] 
+ + return ( + x[train_idx], + y[train_idx], + x[val_idx], + y[val_idx], + x[test_idx], + y[test_idx], + ) + + +def _choose_device(device_text: str) -> torch.device: + normalized = str(device_text).strip().lower() + if normalized.startswith("cuda") and torch.cuda.is_available(): + return torch.device(device_text) + return torch.device("cpu") + + +def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader: + ds = TensorDataset( + torch.from_numpy(x.astype(np.float32)), + torch.from_numpy(y.astype(np.float32)), + ) + return DataLoader(ds, batch_size=batch_size, shuffle=shuffle) + + +def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray: + probs: List[np.ndarray] = [] + model.eval() + with torch.no_grad(): + for xb, _ in loader: + xb = xb.to(device) + logits = model(xb) + p = torch.sigmoid(logits).detach().cpu().numpy() + probs.append(p) + if not probs: + return np.zeros((0,), dtype=np.float32) + return np.concatenate(probs, axis=0).reshape(-1) + + +def train_unet_track_o( + x_train: np.ndarray, + y_train: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + cfg: UNetTrackOConfig, +): + if x_train.ndim != 4 or x_val.ndim != 4: + raise ValueError("x_train and x_val must be 4D arrays [N,C,H,W]") + if y_train.ndim != 4 or y_val.ndim != 4: + raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]") + + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + + device = _choose_device(cfg.device) + + model = TinyUNet(in_channels=cfg.in_channels, base_channels=cfg.base_channels).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) + + total_px = float(y_train.size) + pos_px = float(np.sum(y_train)) + neg_px = max(1.0, total_px - pos_px) + raw_pos_weight = neg_px / max(pos_px, 1.0) + pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max)) + + criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], 
device=device)) + + train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True) + val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False) + + history: List[Dict[str, float]] = [] + best_epoch = 1 + best_val_loss = float("inf") + best_state: Dict[str, torch.Tensor] | None = None + wait = 0 + + for epoch in range(1, cfg.max_epochs + 1): + model.train() + train_losses: List[float] = [] + + for xb, yb in train_loader: + xb = xb.to(device) + yb = yb.to(device) + + optimizer.zero_grad(set_to_none=True) + logits = model(xb) + loss = criterion(logits, yb) + loss.backward() + optimizer.step() + train_losses.append(float(loss.item())) + + model.eval() + val_losses: List[float] = [] + with torch.no_grad(): + for xb, yb in val_loader: + xb = xb.to(device) + yb = yb.to(device) + logits = model(xb) + loss = criterion(logits, yb) + val_losses.append(float(loss.item())) + + tr_loss = float(np.mean(train_losses)) if train_losses else float("nan") + va_loss = float(np.mean(val_losses)) if val_losses else float("nan") + + history.append( + { + "epoch": float(epoch), + "train_loss": tr_loss, + "val_loss": va_loss, + "learning_rate": float(optimizer.param_groups[0]["lr"]), + } + ) + + if va_loss < best_val_loss - cfg.min_delta: + best_val_loss = va_loss + best_epoch = epoch + best_state = deepcopy(model.state_dict()) + wait = 0 + else: + wait += 1 + + if wait >= cfg.early_stopping_rounds: + break + + if best_state is not None: + model.load_state_dict(best_state) + + val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7) + val_true = y_val.reshape(-1).astype(np.float32) + + mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0 + metrics = { + "auprc": float(average_precision_score(val_true, val_prob)), + "auroc": float(roc_auc_score(val_true, val_prob)), + "brier": float(brier_score_loss(val_true, val_prob)), + "nll": float(log_loss(val_true, 
val_prob)), + "ece": float(binary_ece(val_true, val_prob, n_bins=15)), + "mean_day_to_day_change": mean_change, + "normalized_consistency_score": normalized_consistency_score(mean_change), + } + + return model, history, metrics, best_epoch, pos_weight + + +def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + history_csv = output_dir / "history.csv" + with history_csv.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"]) + writer.writeheader() + writer.writerows(history) + + x = [int(r["epoch"]) for r in history] + y_tr = [float(r["train_loss"]) for r in history] + y_va = [float(r["val_loss"]) for r in history] + + plt.figure(figsize=(8, 5)) + plt.plot(x, y_tr, label="train_bce", marker="o", linewidth=1.4) + plt.plot(x, y_va, label="val_bce", marker="s", linewidth=1.2) + plt.xlabel("epoch") + plt.ylabel("loss") + plt.title("U-Net Track-O: train loss vs epoch") + plt.grid(alpha=0.3) + plt.legend() + plt.tight_layout() + plt.savefig(output_dir / "loss_curve.png", dpi=150) + plt.close() + + +def build_experiment_setting( + cfg: UNetTrackOConfig, + best_epoch: int, + pos_weight: float, + metrics: Dict[str, float], +) -> Dict[str, Any]: + return { + "benchmark": { + "task": "Track-O", + "model_name": "unet", + "run_time": datetime.now().isoformat(), + }, + "evaluation_protocol": { + "discrimination": {"primary": "auprc", "secondary": "auroc"}, + "reliability": ["brier", "nll", "ece"], + "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"], + }, + "training": { + "train_unit": "epoch", + "max_epochs": cfg.max_epochs, + "early_stopping_rounds": cfg.early_stopping_rounds, + "best_epoch": best_epoch, + "seed": cfg.seed, + "batch_size": cfg.batch_size, + }, + "optimizer": { + "name": "AdamW", + "lr": cfg.lr, + "weight_decay": cfg.weight_decay, + }, + "learning_weight": { + 
"type": "pixel_pos_weight", + "value": pos_weight, + "clip_max": cfg.pos_weight_clip_max, + }, + "params": asdict(cfg), + "val_metrics": metrics, + "note": "This module supports both real data and synthetic smoke demonstration.", + } + + +def run_synthetic_demo( + output_dir: Path, + seed: int = 42, + n_samples: int = 192, + image_size: int = 24, + max_epochs: int = 60, + early_stopping_rounds: int = 12, +) -> None: + x, y = make_synthetic_fire_maps(n_samples=n_samples, image_size=image_size, seed=seed) + x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed) + + cfg = UNetTrackOConfig( + seed=seed, + max_epochs=max_epochs, + early_stopping_rounds=early_stopping_rounds, + device="cuda" if torch.cuda.is_available() else "cpu", + ) + + model, history, val_metrics, best_epoch, pos_weight = train_unet_track_o(x_train, y_train, x_val, y_val, cfg) + + test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False) + test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 1e-7) + test_true = y_test.reshape(-1).astype(np.float32) + + test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0 + test_metrics = { + "auprc": float(average_precision_score(test_true, test_prob)), + "auroc": float(roc_auc_score(test_true, test_prob)), + "brier": float(brier_score_loss(test_true, test_prob)), + "nll": float(log_loss(test_true, test_prob)), + "ece": float(binary_ece(test_true, test_prob, n_bins=15)), + "mean_day_to_day_change": test_mean_change, + "normalized_consistency_score": normalized_consistency_score(test_mean_change), + } + + output_dir.mkdir(parents=True, exist_ok=True) + save_history_and_plot(history, output_dir) + + torch.save( + { + "state_dict": model.state_dict(), + "config": asdict(cfg), + "best_epoch": best_epoch, + }, + output_dir / "unet_model.pt", + ) + + setting = build_experiment_setting(cfg, best_epoch=best_epoch, 
pos_weight=pos_weight, metrics=val_metrics) + setting["data"] = { + "n_samples": n_samples, + "image_size": image_size, + "split": {"train": int(x_train.shape[0]), "val": int(x_val.shape[0]), "test": int(x_test.shape[0])}, + } + setting["test_metrics"] = test_metrics + + (output_dir / "experiment_setting.json").write_text(json.dumps(setting, indent=2), encoding="utf-8") + (output_dir / "metrics.json").write_text( + json.dumps({"val": val_metrics, "test": test_metrics, "best_epoch": best_epoch}, indent=2), + encoding="utf-8", + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run U-Net Track-O synthetic smoke demo") + parser.add_argument("--output_dir", default=None) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--n_samples", type=int, default=192) + parser.add_argument("--image_size", type=int, default=24) + parser.add_argument("--max_epochs", type=int, default=60) + parser.add_argument("--early_stopping_rounds", type=int, default=12) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + base = Path(__file__).resolve().parents[1] / "runs_scaffold" + out = Path(args.output_dir) if args.output_dir else base / f"unet_synthetic_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + run_synthetic_demo( + output_dir=out, + seed=args.seed, + n_samples=args.n_samples, + image_size=args.image_size, + max_epochs=args.max_epochs, + early_stopping_rounds=args.early_stopping_rounds, + ) + print(f"[done] unet synthetic demo saved to: {out}") + + +if __name__ == "__main__": + main() + +from ._wildfire_benchmark_utils import SegmentationPort, filter_init_kwargs, require_task + + +def unet_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module: + require_task(task, {"segmentation"}, "unet") + init_kwargs = filter_init_kwargs(TinyUNet, {"in_channels": int(in_channels), **kwargs}) + model = TinyUNet(**init_kwargs) + return SegmentationPort(model=model, 
out_channels=int(out_dim)) + + +__all__ = ["TinyUNet", "unet_builder"] diff --git a/pyhazards/models/utae.py b/pyhazards/models/utae.py new file mode 100644 index 00000000..2b5ea2ef --- /dev/null +++ b/pyhazards/models/utae.py @@ -0,0 +1,446 @@ +from __future__ import annotations + +import argparse +import csv +import json +import sys +from copy import deepcopy +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import torch +from sklearn.metrics import average_precision_score, brier_score_loss, log_loss, roc_auc_score +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + from pyhazards.models.mau import ( + binary_ece, + make_synthetic_fire_sequences, + normalized_consistency_score, + split_train_val_test, + ) +else: + from .mau import ( + binary_ece, + make_synthetic_fire_sequences, + normalized_consistency_score, + split_train_val_test, + ) + + +@dataclass +class UTAETrackOConfig: + seq_len: int = 6 + in_channels: int = 1 + hidden_channels: int = 16 + num_heads: int = 4 + lr: float = 3e-4 + weight_decay: float = 1e-4 + batch_size: int = 8 + max_epochs: int = 120 + early_stopping_rounds: int = 16 + min_delta: float = 1e-4 + seed: int = 42 + pos_weight_clip_max: float = 50.0 + device: str = "cpu" + + +class TemporalSelfBlock(nn.Module): + def __init__(self, dim: int, num_heads: int): + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True) + self.norm2 = nn.LayerNorm(dim) + self.ffn = nn.Sequential( + nn.Linear(dim, dim * 2), + nn.GELU(), + nn.Linear(dim * 2, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = self.norm1(x) + y, _ = self.attn(y, 
y, y, need_weights=False) + x = x + y + x = x + self.ffn(self.norm2(x)) + return x + + +class TinyUTAE(nn.Module): + def __init__(self, in_channels: int = 1, hidden_channels: int = 16, num_heads: int = 4): + super().__init__() + self.hidden_channels = hidden_channels + + self.frame_encoder = nn.Sequential( + nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ) + self.temporal_block = TemporalSelfBlock(hidden_channels, num_heads) + + self.query_token = nn.Parameter(torch.zeros(1, 1, hidden_channels)) + self.cross_attn = nn.MultiheadAttention(embed_dim=hidden_channels, num_heads=num_heads, batch_first=True) + self.decoder = nn.Sequential( + nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_channels, 1, kernel_size=1), + ) + + nn.init.normal_(self.query_token, mean=0.0, std=0.02) + + def forward(self, x_seq: torch.Tensor) -> torch.Tensor: + # x_seq: [B,T,C,H,W] + b, t, c, h, w = x_seq.shape + + x = x_seq.reshape(b * t, c, h, w) + x = self.frame_encoder(x) + d = x.shape[1] + + x = x.reshape(b, t, d, h, w) + x = x.permute(0, 3, 4, 1, 2).contiguous().reshape(b * h * w, t, d) # [BHW,T,D] + x = self.temporal_block(x) + + q = self.query_token.expand(x.shape[0], -1, -1) + agg, _ = self.cross_attn(q, x, x, need_weights=False) # [BHW,1,D] + agg = agg[:, 0, :] + + agg = agg.reshape(b, h, w, d).permute(0, 3, 1, 2).contiguous() + return self.decoder(agg) + + +def _choose_device(device_text: str) -> torch.device: + if device_text == "cuda" and torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader: + ds = TensorDataset( + torch.from_numpy(x.astype(np.float32)), + torch.from_numpy(y.astype(np.float32)), + ) + return DataLoader(ds, 
batch_size=batch_size, shuffle=shuffle) + + +def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray: + probs: List[np.ndarray] = [] + model.eval() + with torch.no_grad(): + for xb, _ in loader: + xb = xb.to(device) + logits = model(xb) + p = torch.sigmoid(logits).detach().cpu().numpy() + probs.append(p) + if not probs: + return np.zeros((0,), dtype=np.float32) + return np.concatenate(probs, axis=0).reshape(-1) + + +def train_utae_track_o( + x_train: np.ndarray, + y_train: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + cfg: UTAETrackOConfig, +): + if x_train.ndim != 5 or x_val.ndim != 5: + raise ValueError("x_train and x_val must be 5D arrays [N,T,C,H,W]") + if y_train.ndim != 4 or y_val.ndim != 4: + raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]") + + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + + device = _choose_device(cfg.device) + + model = TinyUTAE( + in_channels=cfg.in_channels, + hidden_channels=cfg.hidden_channels, + num_heads=cfg.num_heads, + ).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) + + total_px = float(y_train.size) + pos_px = float(np.sum(y_train)) + neg_px = max(1.0, total_px - pos_px) + raw_pos_weight = neg_px / max(pos_px, 1.0) + pos_weight = float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max)) + + criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device)) + + train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True) + val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False) + + history: List[Dict[str, float]] = [] + best_epoch = 1 + best_val_loss = float("inf") + best_state: Dict[str, torch.Tensor] | None = None + wait = 0 + + for epoch in range(1, cfg.max_epochs + 1): + model.train() + train_losses: List[float] = [] + + for xb, yb in train_loader: + xb = xb.to(device) + yb = yb.to(device) + + 
optimizer.zero_grad(set_to_none=True) + logits = model(xb) + loss = criterion(logits, yb) + loss.backward() + optimizer.step() + train_losses.append(float(loss.item())) + + model.eval() + val_losses: List[float] = [] + with torch.no_grad(): + for xb, yb in val_loader: + xb = xb.to(device) + yb = yb.to(device) + logits = model(xb) + loss = criterion(logits, yb) + val_losses.append(float(loss.item())) + + tr_loss = float(np.mean(train_losses)) if train_losses else float("nan") + va_loss = float(np.mean(val_losses)) if val_losses else float("nan") + + history.append( + { + "epoch": float(epoch), + "train_loss": tr_loss, + "val_loss": va_loss, + "learning_rate": float(optimizer.param_groups[0]["lr"]), + } + ) + + if va_loss < best_val_loss - cfg.min_delta: + best_val_loss = va_loss + best_epoch = epoch + best_state = deepcopy(model.state_dict()) + wait = 0 + else: + wait += 1 + + if wait >= cfg.early_stopping_rounds: + break + + if best_state is not None: + model.load_state_dict(best_state) + + val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7) + val_true = y_val.reshape(-1).astype(np.float32) + + mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0 + metrics = { + "auprc": float(average_precision_score(val_true, val_prob)), + "auroc": float(roc_auc_score(val_true, val_prob)), + "brier": float(brier_score_loss(val_true, val_prob)), + "nll": float(log_loss(val_true, val_prob)), + "ece": float(binary_ece(val_true, val_prob, n_bins=15)), + "mean_day_to_day_change": mean_change, + "normalized_consistency_score": normalized_consistency_score(mean_change), + } + + return model, history, metrics, best_epoch, pos_weight + + +def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + history_csv = output_dir / "history.csv" + with history_csv.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, 
fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"]) + writer.writeheader() + writer.writerows(history) + + x = [int(r["epoch"]) for r in history] + y_tr = [float(r["train_loss"]) for r in history] + y_va = [float(r["val_loss"]) for r in history] + + plt.figure(figsize=(8, 5)) + plt.plot(x, y_tr, label="train_bce", marker="o", linewidth=1.4) + plt.plot(x, y_va, label="val_bce", marker="s", linewidth=1.2) + plt.xlabel("epoch") + plt.ylabel("loss") + plt.title("UTAE Track-O: train loss vs epoch") + plt.grid(alpha=0.3) + plt.legend() + plt.tight_layout() + plt.savefig(output_dir / "loss_curve.png", dpi=150) + plt.close() + + +def build_experiment_setting( + cfg: UTAETrackOConfig, + best_epoch: int, + pos_weight: float, + metrics: Dict[str, float], +) -> Dict[str, Any]: + return { + "benchmark": { + "task": "Track-O", + "model_name": "utae", + "run_time": datetime.now().isoformat(), + }, + "evaluation_protocol": { + "discrimination": {"primary": "auprc", "secondary": "auroc"}, + "reliability": ["brier", "nll", "ece"], + "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"], + }, + "training": { + "train_unit": "epoch", + "max_epochs": cfg.max_epochs, + "early_stopping_rounds": cfg.early_stopping_rounds, + "best_epoch": best_epoch, + "seed": cfg.seed, + "batch_size": cfg.batch_size, + "seq_len": cfg.seq_len, + }, + "optimizer": { + "name": "AdamW", + "lr": cfg.lr, + "weight_decay": cfg.weight_decay, + }, + "learning_weight": { + "type": "pixel_pos_weight", + "value": pos_weight, + "clip_max": cfg.pos_weight_clip_max, + }, + "params": asdict(cfg), + "val_metrics": metrics, + "note": "This module supports both real data and synthetic smoke demonstration.", + } + + +def run_synthetic_demo( + output_dir: Path, + seed: int = 42, + n_samples: int = 160, + seq_len: int = 6, + image_size: int = 24, + max_epochs: int = 60, + early_stopping_rounds: int = 12, +) -> None: + x, y = make_synthetic_fire_sequences( + n_samples=n_samples, + 
seq_len=seq_len, + image_size=image_size, + seed=seed, + ) + x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(x, y, seed=seed) + + cfg = UTAETrackOConfig( + seed=seed, + seq_len=seq_len, + max_epochs=max_epochs, + early_stopping_rounds=early_stopping_rounds, + device="cpu", + ) + + model, history, val_metrics, best_epoch, pos_weight = train_utae_track_o( + x_train, + y_train, + x_val, + y_val, + cfg, + ) + + test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False) + test_prob = np.clip(_predict_probabilities(model, test_loader, _choose_device(cfg.device)), 1e-7, 1.0 - 1e-7) + test_true = y_test.reshape(-1).astype(np.float32) + + test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob))))) if len(test_prob) > 1 else 0.0 + test_metrics = { + "auprc": float(average_precision_score(test_true, test_prob)), + "auroc": float(roc_auc_score(test_true, test_prob)), + "brier": float(brier_score_loss(test_true, test_prob)), + "nll": float(log_loss(test_true, test_prob)), + "ece": float(binary_ece(test_true, test_prob, n_bins=15)), + "mean_day_to_day_change": test_mean_change, + "normalized_consistency_score": normalized_consistency_score(test_mean_change), + } + + output_dir.mkdir(parents=True, exist_ok=True) + save_history_and_plot(history, output_dir) + + torch.save( + { + "state_dict": model.state_dict(), + "config": asdict(cfg), + "best_epoch": best_epoch, + }, + output_dir / "utae_model.pt", + ) + + setting = build_experiment_setting(cfg, best_epoch=best_epoch, pos_weight=pos_weight, metrics=val_metrics) + setting["data"] = { + "n_samples": n_samples, + "image_size": image_size, + "seq_len": seq_len, + "split": {"train": int(x_train.shape[0]), "val": int(x_val.shape[0]), "test": int(x_test.shape[0])}, + } + setting["test_metrics"] = test_metrics + + (output_dir / "experiment_setting.json").write_text(json.dumps(setting, indent=2), encoding="utf-8") + (output_dir / "metrics.json").write_text( + json.dumps({"val": 
val_metrics, "test": test_metrics, "best_epoch": best_epoch}, indent=2), + encoding="utf-8", + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run UTAE Track-O synthetic smoke demo") + parser.add_argument("--output_dir", default=None) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--n_samples", type=int, default=160) + parser.add_argument("--seq_len", type=int, default=6) + parser.add_argument("--image_size", type=int, default=24) + parser.add_argument("--max_epochs", type=int, default=60) + parser.add_argument("--early_stopping_rounds", type=int, default=12) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + base = Path(__file__).resolve().parents[1] / "runs_scaffold" + out = ( + Path(args.output_dir) + if args.output_dir + else base / f"utae_synthetic_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + ) + + run_synthetic_demo( + output_dir=out, + seed=args.seed, + n_samples=args.n_samples, + seq_len=args.seq_len, + image_size=args.image_size, + max_epochs=args.max_epochs, + early_stopping_rounds=args.early_stopping_rounds, + ) + print(f"[done] utae synthetic demo saved to: {out}") + + +if __name__ == "__main__": + main() + +from ._wildfire_benchmark_utils import SegmentationPort, filter_init_kwargs, require_task + + +def utae_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module: + require_task(task, {"segmentation"}, "utae") + init_kwargs = filter_init_kwargs(TinyUTAE, {"in_channels": int(in_channels), **kwargs}) + model = TinyUTAE(**init_kwargs) + return SegmentationPort(model=model, out_channels=int(out_dim)) + + +__all__ = ["TinyUTAE", "utae_builder"] diff --git a/pyhazards/models/viirs_375m_active_fire.py b/pyhazards/models/viirs_375m_active_fire.py new file mode 100644 index 00000000..3d5ed6cc --- /dev/null +++ b/pyhazards/models/viirs_375m_active_fire.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import torch 
+import torch.nn as nn +import torch.nn.functional as F + + +class VIIRS375mActiveFire(nn.Module): + """Algorithm-inspired VIIRS 375 m active-fire detector with a learnable calibration head.""" + + def __init__( + self, + in_channels: int = 5, + hidden_dim: int = 24, + out_dim: int = 1, + context_kernel: int = 7, + dropout: float = 0.1, + ): + super().__init__() + if in_channels < 5: + raise ValueError( + "VIIRS375mActiveFire expects at least 5 channels: " + "mid_ir, long_ir, frp_proxy, clear_sky, dryness." + ) + if hidden_dim <= 0: + raise ValueError(f"hidden_dim must be positive, got {hidden_dim}") + if out_dim <= 0: + raise ValueError(f"out_dim must be positive, got {out_dim}") + if context_kernel <= 1 or context_kernel % 2 == 0: + raise ValueError(f"context_kernel must be an odd integer > 1, got {context_kernel}") + if not 0.0 <= dropout < 1.0: + raise ValueError(f"dropout must be in [0,1), got {dropout}") + + self.in_channels = int(in_channels) + self.context_kernel = int(context_kernel) + + evidence_channels = self.in_channels + 4 + self.context_pool = nn.AvgPool2d(kernel_size=context_kernel, stride=1, padding=context_kernel // 2) + self.evidence_encoder = nn.Sequential( + nn.Conv2d(evidence_channels, hidden_dim, kernel_size=1), + nn.GELU(), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), + nn.GELU(), + ) + self.calibration_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), + nn.GELU(), + nn.Dropout2d(dropout) if dropout > 0 else nn.Identity(), + nn.Conv2d(hidden_dim, out_dim, kernel_size=1), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.ndim != 4: + raise ValueError( + "VIIRS375mActiveFire expects input shape (batch, channels, height, width), " + f"got {tuple(x.shape)}." 
+ ) + if x.size(1) < 5: + raise ValueError(f"VIIRS375mActiveFire expected at least 5 channels, got {x.size(1)}.") + + x = x[:, : self.in_channels] + mid_ir = x[:, 0:1] + long_ir = x[:, 1:2] + frp_proxy = x[:, 2:3] + clear_sky = x[:, 3:4] + dryness = x[:, 4:5] + + local_background = self.context_pool(mid_ir) + thermal_excess = mid_ir - local_background + split_window = mid_ir - long_ir + fire_evidence = F.relu(thermal_excess) + 0.5 * F.relu(split_window) + confidence_gate = torch.sigmoid(clear_sky) * torch.sigmoid(dryness) + + evidence = torch.cat( + [ + x, + thermal_excess, + split_window, + fire_evidence, + confidence_gate + frp_proxy, + ], + dim=1, + ) + encoded = self.evidence_encoder(evidence) + return self.calibration_head(encoded) + + +def viirs_375m_active_fire_builder( + task: str, + in_channels: int = 5, + hidden_dim: int = 24, + out_dim: int = 1, + context_kernel: int = 7, + dropout: float = 0.1, + **kwargs, +) -> nn.Module: + _ = kwargs + if task.lower() != "segmentation": + raise ValueError( + f"viirs_375m_active_fire is segmentation-only in PyHazards, got task={task!r}." 
+ ) + return VIIRS375mActiveFire( + in_channels=in_channels, + hidden_dim=hidden_dim, + out_dim=out_dim, + context_kernel=context_kernel, + dropout=dropout, + ) + + +__all__ = ["VIIRS375mActiveFire", "viirs_375m_active_fire_builder"] diff --git a/pyhazards/models/vit_segmenter.py b/pyhazards/models/vit_segmenter.py new file mode 100644 index 00000000..2d15866b --- /dev/null +++ b/pyhazards/models/vit_segmenter.py @@ -0,0 +1,471 @@ +from __future__ import annotations + +import argparse +import csv +import json +import sys +from copy import deepcopy +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F +from sklearn.metrics import average_precision_score, brier_score_loss, log_loss, roc_auc_score +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + from pyhazards.models.mau import ( + binary_ece, + make_synthetic_fire_sequences, + normalized_consistency_score, + split_train_val_test, + ) +else: + from .mau import ( + binary_ece, + make_synthetic_fire_sequences, + normalized_consistency_score, + split_train_val_test, + ) + + +@dataclass +class ViTSegmenterTrackOConfig: + seq_len: int = 6 + in_channels: int = 1 + patch_size: int = 4 + embed_dim: int = 64 + depth: int = 4 + num_heads: int = 4 + mlp_ratio: float = 2.0 + dropout: float = 0.1 + lr: float = 2e-4 + weight_decay: float = 1e-4 + batch_size: int = 8 + max_epochs: int = 120 + early_stopping_rounds: int = 16 + min_delta: float = 1e-4 + seed: int = 42 + pos_weight_clip_max: float = 50.0 + device: str = "cpu" + + +class TransformerBlock(nn.Module): + def __init__(self, dim: int, num_heads: int, mlp_ratio: float, dropout: float): + super().__init__() + 
self.norm1 = nn.LayerNorm(dim) + self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True) + self.norm2 = nn.LayerNorm(dim) + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(int(dim * mlp_ratio), dim), + nn.Dropout(dropout), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [B,N,D] + y = self.norm1(x) + y, _ = self.attn(y, y, y, need_weights=False) + x = x + y + x = x + self.mlp(self.norm2(x)) + return x + + +class TinyViTSegmenter(nn.Module): + def __init__( + self, + in_channels: int = 1, + patch_size: int = 4, + embed_dim: int = 64, + depth: int = 4, + num_heads: int = 4, + mlp_ratio: float = 2.0, + dropout: float = 0.1, + ): + super().__init__() + self.patch_size = patch_size + + self.patch_embed = nn.Conv2d( + in_channels, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + padding=0, + ) + self.blocks = nn.ModuleList( + [TransformerBlock(embed_dim, num_heads, mlp_ratio=mlp_ratio, dropout=dropout) for _ in range(depth)] + ) + self.norm = nn.LayerNorm(embed_dim) + + self.seg_head = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim // 2, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(embed_dim // 2, 1, kernel_size=1), + ) + + def _temporal_fusion(self, x_seq: torch.Tensor) -> torch.Tensor: + # x_seq: [B,T,C,H,W] + frame_scores = x_seq.mean(dim=(2, 3, 4)) + weights = torch.softmax(frame_scores, dim=1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + return torch.sum(x_seq * weights, dim=1) + + def forward(self, x_seq: torch.Tensor) -> torch.Tensor: + # x_seq: [B,T,C,H,W] + _, _, _, h, w = x_seq.shape + x = self._temporal_fusion(x_seq) # [B,C,H,W] + + feat = self.patch_embed(x) # [B,D,Hp,Wp] + b, d, hp, wp = feat.shape + + tokens = feat.flatten(2).transpose(1, 2).contiguous() # [B,N,D] + for blk in self.blocks: + tokens = blk(tokens) + tokens = self.norm(tokens) + + feat = tokens.transpose(1, 2).reshape(b, d, hp, 
wp).contiguous() + logits_small = self.seg_head(feat) + return F.interpolate(logits_small, size=(h, w), mode="bilinear", align_corners=False) + + +def _choose_device(device_text: str) -> torch.device: + if device_text == "cuda" and torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +def _build_loader(x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool) -> DataLoader: + ds = TensorDataset( + torch.from_numpy(x.astype(np.float32)), + torch.from_numpy(y.astype(np.float32)), + ) + return DataLoader(ds, batch_size=batch_size, shuffle=shuffle) + + +def _predict_probabilities(model: nn.Module, loader: DataLoader, device: torch.device) -> np.ndarray: + probs: List[np.ndarray] = [] + model.eval() + with torch.no_grad(): + for xb, _ in loader: + xb = xb.to(device) + logits = model(xb) + p = torch.sigmoid(logits).detach().cpu().numpy() + probs.append(p) + if not probs: + return np.zeros((0,), dtype=np.float32) + return np.concatenate(probs, axis=0).reshape(-1) + + +def train_vit_segmenter_track_o( + x_train: np.ndarray, + y_train: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + cfg: ViTSegmenterTrackOConfig, +): + if x_train.ndim != 5 or x_val.ndim != 5: + raise ValueError("x_train and x_val must be 5D arrays [N,T,C,H,W]") + if y_train.ndim != 4 or y_val.ndim != 4: + raise ValueError("y_train and y_val must be 4D arrays [N,1,H,W]") + + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + + device = _choose_device(cfg.device) + + model = TinyViTSegmenter( + in_channels=cfg.in_channels, + patch_size=cfg.patch_size, + embed_dim=cfg.embed_dim, + depth=cfg.depth, + num_heads=cfg.num_heads, + mlp_ratio=cfg.mlp_ratio, + dropout=cfg.dropout, + ).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) + + total_px = float(y_train.size) + pos_px = float(np.sum(y_train)) + neg_px = max(1.0, total_px - pos_px) + raw_pos_weight = neg_px / max(pos_px, 1.0) + pos_weight = 
float(np.clip(raw_pos_weight, 1.0, cfg.pos_weight_clip_max)) + + criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device)) + + train_loader = _build_loader(x_train, y_train, batch_size=cfg.batch_size, shuffle=True) + val_loader = _build_loader(x_val, y_val, batch_size=cfg.batch_size, shuffle=False) + + history: List[Dict[str, float]] = [] + best_epoch = 1 + best_val_loss = float("inf") + best_state: Dict[str, torch.Tensor] | None = None + wait = 0 + + for epoch in range(1, cfg.max_epochs + 1): + model.train() + train_losses: List[float] = [] + + for xb, yb in train_loader: + xb = xb.to(device) + yb = yb.to(device) + + optimizer.zero_grad(set_to_none=True) + logits = model(xb) + loss = criterion(logits, yb) + loss.backward() + optimizer.step() + train_losses.append(float(loss.item())) + + model.eval() + val_losses: List[float] = [] + with torch.no_grad(): + for xb, yb in val_loader: + xb = xb.to(device) + yb = yb.to(device) + logits = model(xb) + loss = criterion(logits, yb) + val_losses.append(float(loss.item())) + + tr_loss = float(np.mean(train_losses)) if train_losses else float("nan") + va_loss = float(np.mean(val_losses)) if val_losses else float("nan") + + history.append( + { + "epoch": float(epoch), + "train_loss": tr_loss, + "val_loss": va_loss, + "learning_rate": float(optimizer.param_groups[0]["lr"]), + } + ) + + if va_loss < best_val_loss - cfg.min_delta: + best_val_loss = va_loss + best_epoch = epoch + best_state = deepcopy(model.state_dict()) + wait = 0 + else: + wait += 1 + + if wait >= cfg.early_stopping_rounds: + break + + if best_state is not None: + model.load_state_dict(best_state) + + val_prob = np.clip(_predict_probabilities(model, val_loader, device=device), 1e-7, 1.0 - 1e-7) + val_true = y_val.reshape(-1).astype(np.float32) + + mean_change = float(np.mean(np.abs(np.diff(np.sort(val_prob))))) if len(val_prob) > 1 else 0.0 + metrics = { + "auprc": float(average_precision_score(val_true, val_prob)), + "auroc": 
def save_history_and_plot(history: List[Dict[str, float]], output_dir: Path) -> None:
    """Persist the per-epoch training history as CSV and render a loss-curve PNG."""
    output_dir.mkdir(parents=True, exist_ok=True)

    # Raw history rows go to CSV exactly as recorded during training.
    with (output_dir / "history.csv").open("w", encoding="utf-8", newline="") as handle:
        dict_writer = csv.DictWriter(handle, fieldnames=["epoch", "train_loss", "val_loss", "learning_rate"])
        dict_writer.writeheader()
        dict_writer.writerows(history)

    epochs = [int(row["epoch"]) for row in history]
    train_losses = [float(row["train_loss"]) for row in history]
    val_losses = [float(row["val_loss"]) for row in history]

    plt.figure(figsize=(8, 5))
    plt.plot(epochs, train_losses, label="train_bce", marker="o", linewidth=1.4)
    plt.plot(epochs, val_losses, label="val_bce", marker="s", linewidth=1.2)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("ViT Segmenter Track-O: train loss vs epoch")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_dir / "loss_curve.png", dpi=150)
    plt.close()


def build_experiment_setting(
    cfg: ViTSegmenterTrackOConfig,
    best_epoch: int,
    pos_weight: float,
    metrics: Dict[str, float],
) -> Dict[str, Any]:
    """Assemble the experiment-setting payload written next to the run artifacts."""
    setting: Dict[str, Any] = {}
    setting["benchmark"] = {
        "task": "Track-O",
        "model_name": "vit_segmenter",
        "run_time": datetime.now().isoformat(),
    }
    setting["evaluation_protocol"] = {
        "discrimination": {"primary": "auprc", "secondary": "auroc"},
        "reliability": ["brier", "nll", "ece"],
        "temporal_consistency": ["mean_day_to_day_change", "normalized_consistency_score"],
    }
    setting["training"] = {
        "train_unit": "epoch",
        "max_epochs": cfg.max_epochs,
        "early_stopping_rounds": cfg.early_stopping_rounds,
        "best_epoch": best_epoch,
        "seed": cfg.seed,
        "batch_size": cfg.batch_size,
        "seq_len": cfg.seq_len,
    }
    setting["optimizer"] = {
        "name": "AdamW",
        "lr": cfg.lr,
        "weight_decay": cfg.weight_decay,
    }
    setting["learning_weight"] = {
        "type": "pixel_pos_weight",
        "value": pos_weight,
        "clip_max": cfg.pos_weight_clip_max,
    }
    setting["params"] = asdict(cfg)
    setting["val_metrics"] = metrics
    setting["note"] = "This module supports both real data and synthetic smoke demonstration."
    return setting


def run_synthetic_demo(
    output_dir: Path,
    seed: int = 42,
    n_samples: int = 160,
    seq_len: int = 6,
    image_size: int = 24,
    max_epochs: int = 60,
    early_stopping_rounds: int = 12,
) -> None:
    """End-to-end synthetic smoke run: train, score the held-out split, save artifacts."""
    features, labels = make_synthetic_fire_sequences(
        n_samples=n_samples,
        seq_len=seq_len,
        image_size=image_size,
        seed=seed,
    )
    x_train, y_train, x_val, y_val, x_test, y_test = split_train_val_test(features, labels, seed=seed)

    cfg = ViTSegmenterTrackOConfig(
        seed=seed,
        seq_len=seq_len,
        max_epochs=max_epochs,
        early_stopping_rounds=early_stopping_rounds,
        device="cpu",
    )
    model, history, val_metrics, best_epoch, pos_weight = train_vit_segmenter_track_o(
        x_train,
        y_train,
        x_val,
        y_val,
        cfg,
    )

    # Held-out evaluation mirrors the validation protocol, including probability clipping.
    test_loader = _build_loader(x_test, y_test, batch_size=cfg.batch_size, shuffle=False)
    raw_prob = _predict_probabilities(model, test_loader, _choose_device(cfg.device))
    test_prob = np.clip(raw_prob, 1e-7, 1.0 - 1e-7)
    test_true = y_test.reshape(-1).astype(np.float32)

    if len(test_prob) > 1:
        test_mean_change = float(np.mean(np.abs(np.diff(np.sort(test_prob)))))
    else:
        test_mean_change = 0.0
    test_metrics = {
        "auprc": float(average_precision_score(test_true, test_prob)),
        "auroc": float(roc_auc_score(test_true, test_prob)),
        "brier": float(brier_score_loss(test_true, test_prob)),
        "nll": float(log_loss(test_true, test_prob)),
        "ece": float(binary_ece(test_true, test_prob, n_bins=15)),
        "mean_day_to_day_change": test_mean_change,
        "normalized_consistency_score": normalized_consistency_score(test_mean_change),
    }

    output_dir.mkdir(parents=True, exist_ok=True)
    save_history_and_plot(history, output_dir)

    checkpoint = {
        "state_dict": model.state_dict(),
        "config": asdict(cfg),
        "best_epoch": best_epoch,
    }
    torch.save(checkpoint, output_dir / "vit_segmenter_model.pt")

    setting = build_experiment_setting(cfg, best_epoch=best_epoch, pos_weight=pos_weight, metrics=val_metrics)
    setting["data"] = {
        "n_samples": n_samples,
        "image_size": image_size,
        "seq_len": seq_len,
        "split": {"train": int(x_train.shape[0]), "val": int(x_val.shape[0]), "test": int(x_test.shape[0])},
    }
    setting["test_metrics"] = test_metrics

    (output_dir / "experiment_setting.json").write_text(json.dumps(setting, indent=2), encoding="utf-8")
    metrics_payload = {"val": val_metrics, "test": test_metrics, "best_epoch": best_epoch}
    (output_dir / "metrics.json").write_text(json.dumps(metrics_payload, indent=2), encoding="utf-8")


def parse_args() -> argparse.Namespace:
    """CLI arguments for the synthetic smoke demo."""
    parser = argparse.ArgumentParser(description="Run ViT Segmenter Track-O synthetic smoke demo")
    parser.add_argument("--output_dir", default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--n_samples", type=int, default=160)
    parser.add_argument("--seq_len", type=int, default=6)
    parser.add_argument("--image_size", type=int, default=24)
    parser.add_argument("--max_epochs", type=int, default=60)
    parser.add_argument("--early_stopping_rounds", type=int, default=12)
    return parser.parse_args()


def main() -> None:
    """Entry point: resolve the output directory and launch the synthetic demo."""
    args = parse_args()
    if args.output_dir:
        out = Path(args.output_dir)
    else:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        out = Path(__file__).resolve().parents[1] / "runs_scaffold" / f"vit_segmenter_synthetic_{stamp}"

    run_synthetic_demo(
        output_dir=out,
        seed=args.seed,
        n_samples=args.n_samples,
        seq_len=args.seq_len,
        image_size=args.image_size,
        max_epochs=args.max_epochs,
        early_stopping_rounds=args.early_stopping_rounds,
    )
    print(f"[done] vit_segmenter synthetic demo saved to: {out}")


if __name__ == "__main__":
    main()
from ._wildfire_benchmark_utils import SegmentationPort, filter_init_kwargs, require_task


def vit_segmenter_builder(task: str, in_channels: int = 1, out_dim: int = 1, **kwargs: Any) -> nn.Module:
    """Registry entry point: wrap TinyViTSegmenter in a SegmentationPort.

    Only task='segmentation' is accepted; extra keyword arguments are filtered
    down to the TinyViTSegmenter constructor signature before instantiation.
    """
    require_task(task, {"segmentation"}, "vit_segmenter")
    ctor_kwargs = filter_init_kwargs(TinyViTSegmenter, {"in_channels": int(in_channels), **kwargs})
    segmenter = TinyViTSegmenter(**ctor_kwargs)
    return SegmentationPort(model=segmenter, out_channels=int(out_dim))


__all__ = ["TinyViTSegmenter", "vit_segmenter_builder"]
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


class WildfireGPTReasoner(nn.Module):
    """Retrieval-conditioned wildfire risk model inspired by the WildfireGPT multi-agent RAG system.

    Accepts either a raster tensor of shape (B, C, H, W) or a dict with keys
    ``x`` (raster), ``user_profile`` (B, D) and ``retrieved_context`` (B, D).
    The two vector inputs are optional: absent vectors are zero-filled, and
    present ones are truncated/zero-padded to the configured widths. The
    raster summary, both vectors, and three learned agent tokens are fused
    with self-attention, then broadcast back over the grid for per-pixel
    decoding to (B, out_dim, H, W).
    """

    def __init__(
        self,
        in_channels: int = 12,
        out_dim: int = 1,
        base_channels: int = 32,
        hidden_dim: int = 64,
        profile_dim: int = 8,
        retrieved_dim: int = 16,
        num_heads: int = 4,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Validate the full configuration up front so misconfiguration surfaces
        # as a clear ValueError instead of an opaque layer-construction error
        # (e.g. ZeroDivisionError from hidden_dim % num_heads when num_heads=0).
        if in_channels <= 0:
            raise ValueError(f"in_channels must be positive, got {in_channels}")
        if out_dim <= 0:
            raise ValueError(f"out_dim must be positive, got {out_dim}")
        if base_channels <= 0:
            raise ValueError(f"base_channels must be positive, got {base_channels}")
        # The decoder halves hidden_dim, so it must be at least 2.
        if hidden_dim < 2:
            raise ValueError(f"hidden_dim must be >= 2, got {hidden_dim}")
        if profile_dim <= 0:
            raise ValueError(f"profile_dim must be positive, got {profile_dim}")
        if retrieved_dim <= 0:
            raise ValueError(f"retrieved_dim must be positive, got {retrieved_dim}")
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if hidden_dim % num_heads != 0:
            raise ValueError(f"hidden_dim={hidden_dim} must be divisible by num_heads={num_heads}")
        if not 0.0 <= dropout < 1.0:
            raise ValueError(f"dropout must be in [0, 1), got {dropout}")

        self.in_channels = int(in_channels)
        self.profile_dim = int(profile_dim)
        self.retrieved_dim = int(retrieved_dim)
        self.hidden_dim = int(hidden_dim)

        self.raster_encoder = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(base_channels, hidden_dim, kernel_size=3, padding=1),
            nn.GELU(),
        )
        self.profile_proj = nn.Linear(self.profile_dim, hidden_dim)
        self.retrieved_proj = nn.Linear(self.retrieved_dim, hidden_dim)
        self.raster_proj = nn.Linear(hidden_dim, hidden_dim)

        # Learned system-role tokens: user-profile, planner, analyst.
        self.agent_tokens = nn.Parameter(torch.randn(3, hidden_dim) * 0.02)
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.ffn = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(hidden_dim * 2, hidden_dim, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=3, padding=1),
            nn.GELU(),
        )
        self.head = nn.Conv2d(hidden_dim // 2, out_dim, kernel_size=1)

    def _unpack_inputs(
        self, inputs: Union[torch.Tensor, Dict[str, Any]]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Split the forward input into (raster, profile, retrieved) and validate the raster."""
        if isinstance(inputs, dict):
            x = inputs.get("x")
            profile = inputs.get("user_profile")
            retrieved = inputs.get("retrieved_context")
        else:
            x = inputs
            profile = None
            retrieved = None

        if not isinstance(x, torch.Tensor):
            raise ValueError("WildfireGPTReasoner expects a tensor input or a dict containing key 'x'.")
        if x.ndim != 4:
            raise ValueError(
                "WildfireGPTReasoner expects input shape (B, C, H, W), "
                f"got {tuple(x.shape)}."
            )
        if x.size(1) != self.in_channels:
            raise ValueError(
                f"WildfireGPTReasoner expected in_channels={self.in_channels}, got {x.size(1)}."
            )
        return x, profile, retrieved

    def _coerce_profile(self, profile: Optional[torch.Tensor], batch: int, device: torch.device) -> torch.Tensor:
        """Return a (B, profile_dim) float32 tensor: zeros if absent, truncated/zero-padded otherwise."""
        if profile is None:
            return torch.zeros(batch, self.profile_dim, device=device)
        if profile.ndim != 2 or profile.size(0) != batch:
            raise ValueError(f"user_profile must have shape (B,D), got {tuple(profile.shape)}")
        if profile.size(1) == self.profile_dim:
            return profile.to(device=device, dtype=torch.float32)
        if profile.size(1) > self.profile_dim:
            return profile[:, : self.profile_dim].to(device=device, dtype=torch.float32)
        pad = torch.zeros(batch, self.profile_dim - profile.size(1), device=device)
        return torch.cat([profile.to(device=device, dtype=torch.float32), pad], dim=1)

    def _coerce_retrieved(self, retrieved: Optional[torch.Tensor], batch: int, device: torch.device) -> torch.Tensor:
        """Return a (B, retrieved_dim) float32 tensor: zeros if absent, truncated/zero-padded otherwise."""
        if retrieved is None:
            return torch.zeros(batch, self.retrieved_dim, device=device)
        if retrieved.ndim != 2 or retrieved.size(0) != batch:
            raise ValueError(f"retrieved_context must have shape (B,D), got {tuple(retrieved.shape)}")
        if retrieved.size(1) == self.retrieved_dim:
            return retrieved.to(device=device, dtype=torch.float32)
        if retrieved.size(1) > self.retrieved_dim:
            return retrieved[:, : self.retrieved_dim].to(device=device, dtype=torch.float32)
        pad = torch.zeros(batch, self.retrieved_dim - retrieved.size(1), device=device)
        return torch.cat([retrieved.to(device=device, dtype=torch.float32), pad], dim=1)

    def forward(self, inputs: Union[torch.Tensor, Dict[str, Any]]) -> torch.Tensor:
        """Fuse raster features with profile/retrieved tokens and decode a per-pixel map."""
        x, profile, retrieved = self._unpack_inputs(inputs)
        batch = x.size(0)
        device = x.device

        feature_map = self.raster_encoder(x)
        pooled_raster = F.adaptive_avg_pool2d(feature_map, 1).flatten(1)

        profile_token = self.profile_proj(self._coerce_profile(profile, batch, device)).unsqueeze(1)
        retrieved_token = self.retrieved_proj(self._coerce_retrieved(retrieved, batch, device)).unsqueeze(1)
        raster_token = self.raster_proj(pooled_raster).unsqueeze(1)
        agent_tokens = self.agent_tokens.unsqueeze(0).expand(batch, -1, -1)

        tokens = torch.cat([agent_tokens, profile_token, retrieved_token, raster_token], dim=1)
        attn_out, _ = self.attn(tokens, tokens, tokens, need_weights=False)
        fused_tokens = attn_out + self.ffn(attn_out)
        fused_global = fused_tokens.mean(dim=1)

        # Broadcast the fused token summary over the raster grid and decode per pixel.
        fused_map = fused_global.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, feature_map.size(-2), feature_map.size(-1))
        decoded = self.decoder(torch.cat([feature_map, fused_map], dim=1))
        return self.head(decoded)


def wildfiregpt_builder(
    task: str,
    in_channels: int = 12,
    out_dim: int = 1,
    base_channels: int = 32,
    hidden_dim: int = 64,
    profile_dim: int = 8,
    retrieved_dim: int = 16,
    num_heads: int = 4,
    dropout: float = 0.1,
    **kwargs,
) -> nn.Module:
    """Registry entry point for the WildfireGPT-style reasoner (segmentation-only)."""
    _ = kwargs
    if task.lower() != "segmentation":
        raise ValueError(f"wildfiregpt is segmentation-only in PyHazards, got task={task!r}.")
    return WildfireGPTReasoner(
        in_channels=in_channels,
        out_dim=out_dim,
        base_channels=base_channels,
        hidden_dim=hidden_dim,
        profile_dim=profile_dim,
        retrieved_dim=retrieved_dim,
        num_heads=num_heads,
        dropout=dropout,
    )


__all__ = ["WildfireGPTReasoner", "wildfiregpt_builder"]
from typing import Any, Optional

import numpy as np
import torch.nn as nn

from ._wildfire_benchmark_utils import EstimatorPort, filter_init_kwargs, require_task


class XGBoostModel(EstimatorPort):
    """Boosted-tree wildfire occurrence baseline trained with a binary logistic objective."""

    def __init__(
        self,
        max_depth: int = 8,
        eta: float = 0.05,
        subsample: float = 0.8,
        colsample_bytree: float = 0.8,
        num_boost_round: int = 800,
    ):
        super().__init__()
        # Parameters forwarded verbatim to xgboost.train.
        self.params = {
            "objective": "binary:logistic",
            "eval_metric": "logloss",
            "max_depth": int(max_depth),
            "eta": float(eta),
            "subsample": float(subsample),
            "colsample_bytree": float(colsample_bytree),
        }
        self.num_boost_round = int(num_boost_round)
        self.booster = None  # populated by _fit_numpy

    def _fit_numpy(
        self,
        x_train: np.ndarray,
        y_train: np.ndarray,
        x_val: Optional[np.ndarray],
        y_val: Optional[np.ndarray],
    ) -> None:
        """Train the booster, tracking a validation set when one is supplied."""
        import xgboost as xgb  # local import keeps xgboost an optional dependency

        train_matrix = xgb.DMatrix(x_train, label=y_train)
        watchlist = [(train_matrix, "train")]
        if x_val is not None and y_val is not None:
            watchlist.append((xgb.DMatrix(x_val, label=y_val), "val"))
        self.booster = xgb.train(
            params=self.params,
            dtrain=train_matrix,
            num_boost_round=self.num_boost_round,
            evals=watchlist,
            verbose_eval=False,
        )

    def _predict_positive_proba(self, x: np.ndarray) -> np.ndarray:
        """Return per-row positive-class probabilities as float32."""
        if self.booster is None:
            raise RuntimeError("XGBoost booster is not fitted.")
        import xgboost as xgb

        return np.asarray(self.booster.predict(xgb.DMatrix(x)), dtype=np.float32)


def xgboost_builder(task: str, **kwargs: Any) -> nn.Module:
    """Registry entry point: classification-only XGBoost baseline."""
    require_task(task, {"classification"}, "xgboost")
    return XGBoostModel(**filter_init_kwargs(XGBoostModel, kwargs))


__all__ = ["XGBoostModel", "xgboost_builder"]
import argparse
import sys
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parent.parent
# Make the repository importable when the script is run from anywhere.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from pyhazards.benchmarks.wildfire_benchmark.cache_builder import build_cache


def parse_args() -> argparse.Namespace:
    """CLI arguments for the 2024 cache build."""
    parser = argparse.ArgumentParser(description="Build the wildfire 2024 real-data cache for benchmark runs.")
    default_config = REPO_ROOT / "pyhazards" / "configs" / "wildfire_benchmark" / "cache_2024_v1.yaml"
    parser.add_argument("--config", default=str(default_config))
    parser.add_argument(
        "--limit_days",
        type=int,
        default=0,
        help="Only materialize the first N shared dates for smoke-like validation.",
    )
    return parser.parse_args()


def main() -> None:
    """Build the cache and print a one-screen summary of what was materialized."""
    args = parse_args()
    summary = build_cache(args.config, limit_days=int(args.limit_days))
    print("[done] wildfire cache built")
    print(f"cache_root={summary.cache_root}")
    print(f"label_days={summary.n_label_days} met_days={summary.n_met_days} shared_days={summary.n_shared_days}")
    print(f"train={summary.train_days} val={summary.val_days} test={summary.test_days}")
    print(f"weather_vars={','.join(summary.weather_vars)}")


if __name__ == "__main__":
    main()
import argparse
from pathlib import Path

from pyhazards.benchmarks.wildfire_benchmark.real_runner import run_real_baselines


def parse_args() -> argparse.Namespace:
    """CLI arguments for the first real-data baseline batch on the 2024 cache."""
    parser = argparse.ArgumentParser(description="Run first real-data wildfire baselines on the 2024 cache.")
    parser.add_argument("--cache_dir", type=str, default="/home/runyang/my-copy/data_cache/wildfire_2024_v1")
    parser.add_argument("--run_name", type=str, default="track_o_2024_real_v1_first4_dryrun")
    parser.add_argument("--models", type=str, default="logistic_regression,xgboost,unet,convlstm")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--train_limit_days", type=int, default=0)
    parser.add_argument("--val_limit_days", type=int, default=0)
    parser.add_argument("--test_limit_days", type=int, default=0)
    parser.add_argument("--tabular_downsample", type=int, default=8)
    parser.add_argument("--raster_downsample", type=int, default=4)
    parser.add_argument("--temporal_downsample", type=int, default=8)
    parser.add_argument("--temporal_history", type=int, default=6)
    parser.add_argument("--xgboost_rounds", type=int, default=120)
    parser.add_argument("--lightgbm_rounds", type=int, default=120)
    parser.add_argument("--unet_epochs", type=int, default=12)
    parser.add_argument("--convlstm_epochs", type=int, default=12)
    parser.add_argument("--deep_patience", type=int, default=4)
    parser.add_argument("--device", type=str, default=None)
    return parser.parse_args()


def main() -> None:
    """Parse the model list and dispatch to the real-data benchmark runner."""
    args = parse_args()
    model_names = [name.strip() for name in str(args.models).split(",") if name.strip()]
    run_root = run_real_baselines(
        cache_dir=args.cache_dir,
        run_name=args.run_name,
        models=model_names,
        seed=args.seed,
        # A CLI value of 0 means "no limit"; the runner expects None for that.
        train_limit_days=args.train_limit_days or None,
        val_limit_days=args.val_limit_days or None,
        test_limit_days=args.test_limit_days or None,
        tabular_downsample=args.tabular_downsample,
        raster_downsample=args.raster_downsample,
        temporal_downsample=args.temporal_downsample,
        temporal_history=args.temporal_history,
        xgboost_rounds=args.xgboost_rounds,
        lightgbm_rounds=args.lightgbm_rounds,
        unet_epochs=args.unet_epochs,
        convlstm_epochs=args.convlstm_epochs,
        deep_patience=args.deep_patience,
        device=args.device,
    )
    print(f"[done] real wildfire benchmark run written to {Path(run_root)}")


if __name__ == "__main__":
    main()
def main() -> None:
    """Parse CLI options, launch the wildfire smoke batch, and report the run directory."""
    args = parse_args()
    requested = [item.strip() for item in args.models.split(",") if item.strip()]
    selected_models = requested or None  # empty --models means "run every catalog model"
    run_root = run_smoke_batch(
        adapter_factory=create_adapter,
        run_name=args.run_name,
        track=args.track,
        catalog_kind=args.catalog_kind,
        catalog_path=args.catalog_path,
        contract_path=args.contract_path,
        source_tier=args.source_tier,
        models=selected_models,
        seeds=args.seeds,
        limit_models=args.limit_models,
        step_limits={
            "epoch": int(args.max_epoch_steps),
            "round": int(args.max_round_steps),
            "iteration": int(args.max_iter_steps),
            "tree": int(args.max_tree_steps),
        },
    )
    print(f"[done] wildfire benchmark smoke batch saved to: {run_root}")


if __name__ == "__main__":
    main()