Skip to content

Fix RuntimeError by ensuring attn_mask is None when is_causal=True#1

Open
IgorAherne wants to merge 2 commits into
RobertAgee:optimized-chunkingfrom
IgorAherne:attention-mask-correction
Open

Fix RuntimeError by ensuring attn_mask is None when is_causal=True#1
IgorAherne wants to merge 2 commits into
RobertAgee:optimized-chunkingfrom
IgorAherne:attention-mask-correction

Conversation

@IgorAherne
Copy link
Copy Markdown

Prevents an error by modifying the Attention module: avoid passing an explicit attn_mask to F.scaled_dot_product_attention when is_causal=True.

This allows the function's internal causal masking mechanism to be used without conflict, during operations like decoder prefill.

Otherwise getting the following error:

Error processing task 1 ('C:_myDrive\repos\auto-vlog\AutoVlogProj\temp_video_processing\audio_maker_temp\girl and a stick in a forest\chunk_2.wav'): _scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True
Traceback (most recent call last):
File "C:_myDrive\repos\auto-vlog\dia\cli.py", line 168, in process_task
output_audio = model.generate(
^^^^^^^^^^^^^^^
File "C:_myDrive\repos\auto-vlog\venv\Lib\site-packages\torch\utils_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "C:_myDrive\repos\auto-vlog\dia\dia\model.py", line 407, in generate
dec_state, dec_output = self._prepare_generation(full_text, audio_prompt, verbose)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:_myDrive\repos\auto-vlog\dia\dia\model.py", line 289, in _prepare_generation
self.model.decoder.forward(tokens_BxTxC, dec_state)
File "C:_myDrive\repos\auto-vlog\dia\dia\layers.py", line 605, in forward
x = layer(x, state, self_attn_cache=self_cache, cross_attn_cache=cross_cache, prefill=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:_myDrive\repos\auto-vlog\venv\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:_myDrive\repos\auto-vlog\venv\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:_myDrive\repos\auto-vlog\dia\dia\layers.py", line 437, in forward
sa_out = self.self_attention(
^^^^^^^^^^^^^^^^^^^^
File "C:_myDrive\repos\auto-vlog\venv\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:_myDrive\repos\auto-vlog\venv\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:_myDrive\repos\auto-vlog\dia\dia\layers.py", line 251, in forward
attn_output = F.scaled_dot_product_attention(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: _scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True
Skipping task due to error.

@IgorAherne
Copy link
Copy Markdown
Author

Hm, looking at my proposed changes, there is more than necessary somehow.

I wanted to only offer this change:

        # If is_causal is True, sdpa handles masking internally, so attn_mask should be None.
        # Otherwise, use the provided attn_mask (e.g., for padding in encoder).
        effective_attn_mask = attn_mask if not is_causal else None

        attn_output = F.scaled_dot_product_attention(
            Xq_BxNxTxH,
            attn_k,
            attn_v,
            attn_mask=attn_mask,
            attn_mask=effective_attn_mask,     #use it here
            scale=1.0,
            enable_gqa=self.num_gqa_groups > 1,
            is_causal=is_causal,

@IgorAherne
Copy link
Copy Markdown
Author

ok, much less intrusive now, see Files Changed tab

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant