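torch.vmap error in transformers 4.57.1: `RuntimeError: Batching rule not implemented for aten::to.dtype_layout` raised inside `create_causal_mask` (SDPA mask path, `masking_utils.and_mask`) when running GPT-2 355M under DeepSpeed via torchrun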

(py312dbg) user@c28097a3bac8:/workspace/user/llmbenchprofiler$ torchrun --nnodes 1 --nproc_per_node 1 main.py -m GPT2_355m -s 1

[2025-11-20 06:45:05,275] [WARNING] [lr_schedules.py:690:get_lr] Attempting to get learning rate from scheduler before it has started
  0%|                                                                                                                                                                                       | 0/1 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/user/llmbenchprofiler/main.py", line 107, in <module>
[rank0]:     outputs = model(**batch)
[rank0]:               ^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2179, in forward
[rank0]:     loss = self.module(*inputs, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1881, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1829, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1068, in forward
[rank0]:     transformer_outputs = self.transformer(
[rank0]:                           ^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 873, in forward
[rank0]:     causal_mask = create_causal_mask(
[rank0]:                   ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/transformers/masking_utils.py", line 825, in create_causal_mask
[rank0]:     causal_mask = mask_interface(
[rank0]:                   ^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/transformers/masking_utils.py", line 392, in sdpa_mask_recent_torch
[rank0]:     causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank0]:     return vmap_impl(
[rank0]:            ^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank0]:     return _flat_vmap(
[rank0]:            ^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank0]:     batched_outputs = func(*batched_inputs, **kwargs)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank0]:     return vmap_impl(
[rank0]:            ^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank0]:     return _flat_vmap(
[rank0]:            ^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank0]:     batched_outputs = func(*batched_inputs, **kwargs)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank0]:     return vmap_impl(
[rank0]:            ^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank0]:     return _flat_vmap(
[rank0]:            ^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank0]:     batched_outputs = func(*batched_inputs, **kwargs)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank0]:     return vmap_impl(
[rank0]:            ^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank0]:     return _flat_vmap(
[rank0]:            ^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank0]:     batched_outputs = func(*batched_inputs, **kwargs)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/transformers/masking_utils.py", line 54, in and_mask
[rank0]:     result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 146, in __torch_function__
[rank0]:     return func(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Batching rule not implemented for aten::to.dtype_layout; the fallback path doesn't work on out= or view ops.
  0%|                                                                                                                                                                                       | 0/1 [00:00<?, ?it/s]
E1120 06:45:06.741000 68171 torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 0 (pid: 68232) of binary: /workspace/usr/py312dbg/bin/python3
Traceback (most recent call last):
  File "/workspace/usr/py312dbg/bin/torchrun", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/distributed/run.py", line 936, in main
    run(args)
  File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/distributed/run.py", line 927, in run
    elastic_launch(
  File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 156, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/usr/py312dbg/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 293, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
main.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-11-20_06:45:06
  host      : c28097a3bac8
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 68232)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================