pengwa authored and GitHub committed 5d8ce817cb5
Fix simplified layer norm fusion for training (#14866)

### Fix simplified layer norm fusion for training

Co-authored with @prathikr, who identified the bug: https://github.com/microsoft/onnxruntime/issues/14822.

When running the T5 model with DeepSpeed enabled, simplified layer norm is not fused because the device check does not pass: https://github.com/microsoft/onnxruntime/blob/b7fde84341f5e7e4fc8b202e9aabad4d087ec15c/onnxruntime/core/optimizer/layer_norm_fusion.cc#L568. During the pre-training optimization pass there is no device placement yet, so the device check is expected to fail. On the other hand, the device check is still valid: it ensures simplified layer norm fusion is not applied incorrectly for CPU runs. As a mitigation, added a flag indicating whether the fusion is triggered by the pre-training optimization pass. There is a residual risk when running ORTModule training with the CPU EP, but it can be much reduced by checking that the build has CUDA/ROCm enabled.

Repro command:

```
CUDA_VISIBLE_DEVICES=0 python examples/onnxruntime/training/summarization/run_summarization.py --model_name_or_path t5-small --do_train --dataset_name cnn_dailymail --dataset_config "3.0.0" --source_prefix "summarize: " --predict_with_generate --overwrite_output_dir --output_dir /bert_ort/pengwa/output --fp16 --max_steps 1 --logging_steps 1 --deepspeed aml_ds_config_zero_1.json
```