diff --git a/internnav/model/basemodel/navdp/navdp_policy.py b/internnav/model/basemodel/navdp/navdp_policy.py index 316243e7..22e03b7e 100644 --- a/internnav/model/basemodel/navdp/navdp_policy.py +++ b/internnav/model/basemodel/navdp/navdp_policy.py @@ -83,8 +83,10 @@ def __init__(self, config: NavDPModelConfig): self.token_dim = self.config.model_cfg['il']['token_dim'] self.scratch = self.config.model_cfg['il']['scratch'] self.finetune = self.config.model_cfg['il']['finetune'] + _da_ckpt = self.config.model_cfg['il'].get('depth_anything_checkpoint') self.rgbd_encoder = RGBDBackbone( - self.image_size, self.token_dim, memory_size=self.memory_size, finetune=self.finetune, device=self._device + self.image_size, self.token_dim, memory_size=self.memory_size, finetune=self.finetune, device=self._device, + **({'checkpoint': _da_ckpt} if _da_ckpt else {}), ) self.pixel_encoder = PixelGoalBackbone( self.image_size, self.token_dim, pixel_channel=self.pixel_channel, device=self._device