torch.vmap "Batching rule not implemented for aten::to.dtype_layout" error in transformers 4.57.1
(py312dbg) user@c28097a3bac8:/workspace/user/llmbenchprofiler$ torchrun --nnodes 1 --nproc_per_node 1 main.py -m GPT2_355m -s 1
[2025-11-20 06:45:05,275] [WARNING] [lr_schedules.py:690:get_lr] Attempting to get learning rate from scheduler before it has started
0%| | 0/1 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
[rank0]: Traceback (most recent call last):
[rank0]: File "/workspace/user/llmbenchprofiler/main.py", line 107, in <module>
[rank0]: outputs = model(**batch)
[rank0]: ^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank0]: ret_val = func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2179, in forward
[rank0]: loss = self.module(*inputs, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1881, in _call_impl
[rank0]: return inner()
[rank0]: ^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1829, in inner
[rank0]: result = forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1068, in forward
[rank0]: transformer_outputs = self.transformer(
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 873, in forward
[rank0]: causal_mask = create_causal_mask(
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/transformers/masking_utils.py", line 825, in create_causal_mask
[rank0]: causal_mask = mask_interface(
[rank0]: ^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/transformers/masking_utils.py", line 392, in sdpa_mask_recent_torch
[rank0]: causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank0]: return vmap_impl(
[rank0]: ^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank0]: return _flat_vmap(
[rank0]: ^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank0]: batched_outputs = func(*batched_inputs, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank0]: return vmap_impl(
[rank0]: ^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank0]: return _flat_vmap(
[rank0]: ^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank0]: batched_outputs = func(*batched_inputs, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank0]: return vmap_impl(
[rank0]: ^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank0]: return _flat_vmap(
[rank0]: ^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank0]: batched_outputs = func(*batched_inputs, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank0]: return vmap_impl(
[rank0]: ^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank0]: return _flat_vmap(
[rank0]: ^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank0]: batched_outputs = func(*batched_inputs, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/transformers/masking_utils.py", line 54, in and_mask
[rank0]: result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 146, in __torch_function__
[rank0]: return func(*args, **(kwargs or {}))
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Batching rule not implemented for aten::to.dtype_layout; the fallback path doesn't work on out= or view ops.
0%| | 0/1 [00:00<?, ?it/s]
E1120 06:45:06.741000 68171 torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 0 (pid: 68232) of binary: /workspace/usr/py312dbg/bin/python3
Traceback (most recent call last):
File "/workspace/usr/py312dbg/bin/torchrun", line 8, in <module>
sys.exit(main())
^^^^^^
File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/distributed/run.py", line 936, in main
run(args)
File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/distributed/run.py", line 927, in run
elastic_launch(
File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 156, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 293, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
main.py FAILED
------------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2025-11-20_06:45:06
host : c28097a3bac8
rank : 0 (local_rank: 0)
exitcode : 1 (pid: 68232)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
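For context: the failing call in `masking_utils.py` builds the 4-D attention mask by nesting `torch.vmap` over the batch, head, query, and key/value index tensors (`_vmap_for_bhqkv` in the trace). A rough sketch of that pattern, my approximation rather than the actual transformers helper, runs fine on its own; the traceback suggests it is the `.to(result.device)` call inside the vmapped mask function that hits an operator overload (`aten::to.dtype_layout`) with no vmap batching rule:

```python
import torch

def causal(batch_idx, head_idx, q_idx, kv_idx):
    # Scalar predicate: key/value position kv_idx is visible from query q_idx.
    return kv_idx <= q_idx

# Nest vmap over (batch, head, q, kv), innermost dimension first --
# a sketch of what _vmap_for_bhqkv appears to do in the trace above.
mask_fn = causal
for in_dims in [(None, None, None, 0),   # kv dimension
                (None, None, 0, None),   # q dimension
                (None, 0, None, None),   # head dimension
                (0, None, None, None)]:  # batch dimension
    mask_fn = torch.vmap(mask_fn, in_dims=in_dims, out_dims=0)

batch = torch.arange(1)
heads = torch.arange(2)
q_pos = torch.arange(4)
kv_pos = torch.arange(4)

mask = mask_fn(batch, heads, q_pos, kv_pos)
print(mask.shape)  # torch.Size([1, 2, 4, 4])
```

If any mask function composed into this chain calls `.to(...)` on its result (as `and_mask` does at `masking_utils.py:54`), vmap has to batch that op too, and the `RuntimeError` says no batching rule exists for that particular `to` overload. That points at a transformers/torch version interaction in the mask-building path rather than a bug in `main.py` itself.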