Skip to content

fix: use shape index access in compute_3d_position_ids for Qwen VL models#44921

Open
s-zx wants to merge 1 commit intohuggingface:mainfrom
s-zx:fix/qwen3-5-compute-3d-position-ids
Open

fix: use shape index access in compute_3d_position_ids for Qwen VL models#44921
s-zx wants to merge 1 commit intohuggingface:mainfrom
s-zx:fix/qwen3-5-compute-3d-position-ids

Conversation

@s-zx
Copy link

@s-zx s-zx commented Mar 22, 2026

What does this PR do?

Fixes #44918.

compute_3d_position_ids in the Qwen2.5-VL / Qwen3-VL / Qwen3.5 model families destructures inputs_embeds.shape into exactly three variables:

batch_size, seq_length, _ = inputs_embeds.shape

This raises ValueError: too many values to unpack (expected 3) when inputs_embeds has more than three dimensions, which can happen when:

  • The TRL SFT Trainer passes inputs_embeds directly (without input_ids) after processing, producing a batch with an extra dimension
  • stale rope_deltas from a preceding generation step (e.g. during evaluation) causes the elif branch to fire on a subsequent training forward pass

Fix

Replace destructuring with explicit index access:

batch_size, seq_length = inputs_embeds.shape[0], inputs_embeds.shape[1]

This is robust to any tensor with ≥ 2 dimensions and does not change behaviour for the standard 3-D case.

Affected files

  • models/qwen2_5_vl/modular_qwen2_5_vl.py — source modular definition
  • models/qwen2_5_vl/modeling_qwen2_5_vl.py — generated file
  • models/qwen3_vl/modeling_qwen3_vl.py — generated file (inherits the method)
  • models/qwen3_5/modeling_qwen3_5.py — generated file (inherits the method)

…dels

batch_size, seq_length, _ = inputs_embeds.shape raises ValueError when
inputs_embeds has more than three dimensions (e.g. when TRL SFT trainer
passes inputs_embeds directly, or after NEFTune noise injection adds a
leading dimension). Switching to explicit index access makes the code
robust to any number of dimensions >= 2.

Fixes: transformers#44918
Affected models: Qwen2.5-VL, Qwen3-VL, Qwen3.5
@github-actions
Copy link
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen2_5_vl, qwen3_5, qwen3_vl

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Unpacking Qwen3.5 input embeddings fails with trl SFT trainer

2 participants