pengwa authored and GitHub committed f6c81d8aca6
Introduce padding inspector in ORTModule (#14652)

### Introduce padding inspector in ORTModule

In some Transformer-based LLM training recipes, high data sparsity is observed due to 1) token padding (to the max sequence length), and 2) labels containing many `ignore_index` entries used when calculating the loss.

This PR introduces a switch to enable data sparsity inspection, which 1) in the short term, informs training users so they can apply techniques like dynamic batching to amortize the issue, and 2) in the medium and longer term, helps us (the training team) better understand what our training customers' models look like from the perspective of data sparsity (and potentially motivates runtime improvements).

Here is an example of different data sparsity with the same training model architecture and the same training input, but with different user models (a minimal sketch of the underlying density computation follows the two examples).

**Low Embed Density, High Label Density Case - Sentence Classification**

`python -m torch.distributed.launch --nproc_per_node=4 examples/onnxruntime/training/text-classification/run_glue.py --model_name_or_path roberta-large-openai-detector --task_name mnli --do_train --do_eval --max_seq_length 128 --per_device_train_batch_size 32 --learning_rate 2e-5 --num_train_epochs 3 --overwrite_output_dir --output_dir ./outputs/ --per_device_eval_batch_size 32 --seed 1137 --fp16 True --ignore_mismatched_sizes True --optim adamw_ort_fused`

```
>>>Valid token/label density (e.g. valid/total) in passing 10 steps:
| STEP | INPUT TYPE | INPUT NAME | PAD IDX | DENSITY | VALID TOKENS | TOTAL TOKENS | VALID TOKENS/BATCH |
| 60 | EMBED | input_ids | 1 | 35.21 % | 1442 | 4096 | [50, 81, 35, 11, 29, 36, 66, 19, 40, 22, 21, 42, 17, 37, 40, 41, 26, 58, 38, 54, 41, 73, 48, 57, 50, 51, 49, 85, 48, 36, 79, 62] |
| 61 | LABEL | labels | -100 | 100.00 % | 32 | 32 | N/A |
| 62 | EMBED | input_ids | 1 | 30.00 % | 1229 | 4096 | [36, 73, 13, 47, 27, 33, 53, 25, 51, 28, 36, 42, 42, 32, 39, 52, 27, 13, 31, 66, 42, 45, 52, 45, 58, 42, 37, 66, 12, 18, 29, 17] |
| 63 | LABEL | labels | -100 | 100.00 % | 32 | 32 | N/A |
| 64 | EMBED | input_ids | 1 | 26.73 % | 1095 | 4096 | [37, 28, 20, 53, 16, 20, 44, 52, 27, 28, 16, 19, 16, 24, 63, 31, 24, 42, 33, 41, 44, 60, 44, 67, 54, 30, 20, 19, 33, 23, 24, 43] |
| 65 | LABEL | labels | -100 | 100.00 % | 32 | 32 | N/A |
| 66 | EMBED | input_ids | 1 | 30.03 % | 1230 | 4096 | [22, 46, 36, 41, 46, 43, 26, 50, 60, 16, 24, 42, 56, 35, 35, 59, 29, 39, 34, 20, 66, 23, 47, 53, 19, 35, 44, 23, 34, 81, 21, 25] |
| 67 | LABEL | labels | -100 | 100.00 % | 32 | 32 | N/A |
| 68 | EMBED | input_ids | 1 | 31.62 % | 1295 | 4096 | [75, 36, 48, 20, 38, 21, 49, 54, 38, 41, 26, 28, 80, 45, 48, 16, 22, 41, 34, 28, 37, 16, 74, 63, 62, 34, 22, 45, 23, 27, 37, 67] |
| 69 | LABEL | labels | -100 | 100.00 % | 32 | 32 | N/A |
<<<
```

**High Embed Density, Low Label Density Case - Masked Language Model**

`python -m torch.distributed.launch --nproc_per_node=4 examples/onnxruntime/training/language-modeling/run_mlm.py --model_name_or_path bert-base-uncased --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --num_train_epochs 10 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --do_train --do_eval --overwrite_output_dir --output_dir ./outputs/ --seed 1137 --fp16 --report_to none --optim adamw_ort_fused`

```
>>>Valid token/label density (e.g. valid/total) in passing 10 steps:
| STEP | INPUT TYPE | INPUT NAME | PAD IDX | DENSITY | VALID TOKENS | TOTAL TOKENS | VALID TOKENS/BATCH |
| 710 | EMBED | input_ids | 0 | 100.00 % | 4096 | 4096 | [512, 512, 512, 512, 512, 512, 512, 512] |
| 711 | LABEL | labels | -100 | 13.77 % | 564 | 4096 | N/A |
| 712 | EMBED | input_ids | 0 | 100.00 % | 4096 | 4096 | [512, 512, 512, 512, 512, 512, 512, 512] |
| 713 | LABEL | labels | -100 | 14.48 % | 593 | 4096 | N/A |
| 714 | EMBED | input_ids | 0 | 100.00 % | 4096 | 4096 | [512, 512, 512, 512, 512, 512, 512, 512] |
| 715 | LABEL | labels | -100 | 14.18 % | 581 | 4096 | N/A |
| 716 | EMBED | input_ids | 0 | 100.00 % | 4096 | 4096 | [512, 512, 512, 512, 512, 512, 512, 512] |
| 717 | LABEL | labels | -100 | 14.53 % | 595 | 4096 | N/A |
| 718 | EMBED | input_ids | 0 | 100.00 % | 4096 | 4096 | [512, 512, 512, 512, 512, 512, 512, 512] |
| 719 | LABEL | labels | -100 | 15.31 % | 627 | 4096 | N/A |
<<<
```
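For readers who want to reproduce numbers like the ones above outside of ORTModule, the density values reduce to counting token positions that are not the padding id and label positions that are not `ignore_index`. Below is a minimal, illustrative sketch in plain PyTorch; the helper name `report_batch_density`, the print format, and the toy batch are assumptions for illustration and are not the actual inspector code added by this PR.

```python
import torch


def report_batch_density(input_ids: torch.Tensor, pad_token_id: int,
                         labels: torch.Tensor, ignore_index: int = -100) -> None:
    """Print EMBED/LABEL density for one batch (hypothetical helper, not the ORTModule inspector)."""
    # EMBED density: positions whose token id is not the padding id.
    valid_tokens_per_sample = (input_ids != pad_token_id).sum(dim=-1)  # shape: [batch]
    valid_tokens = int(valid_tokens_per_sample.sum())
    total_tokens = input_ids.numel()
    print(f"EMBED input_ids pad_idx={pad_token_id} "
          f"density={100.0 * valid_tokens / total_tokens:.2f} % "
          f"valid={valid_tokens} total={total_tokens} "
          f"per_sample={valid_tokens_per_sample.tolist()}")

    # LABEL density: label positions that are not ignore_index (-100 by default in PyTorch losses).
    valid_labels = int((labels != ignore_index).sum())
    total_labels = labels.numel()
    print(f"LABEL labels pad_idx={ignore_index} "
          f"density={100.0 * valid_labels / total_labels:.2f} % "
          f"valid={valid_labels} total={total_labels}")


# Toy example: 4 sequences padded to length 8 with pad id 1 (RoBERTa-style),
# and one classification label per sequence, so label density is 100 %.
input_ids = torch.tensor([[0, 5, 6, 7, 2, 1, 1, 1],
                          [0, 5, 2, 1, 1, 1, 1, 1],
                          [0, 5, 6, 2, 1, 1, 1, 1],
                          [0, 5, 6, 7, 8, 2, 1, 1]])
labels = torch.tensor([2, 0, 1, 2])
report_batch_density(input_ids, pad_token_id=1, labels=labels)
```

A low EMBED density means most per-token compute is spent on padding positions, which is exactly the waste the inspector is meant to surface.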
#### Next Step

Let's see how we can leverage this data sparsity information for improvements. Optimizations on the way as part of compute optimizer wave 2:
> Loss compute FLOPs reduction.
> Flatten/unflatten embedding tokens to save compute FLOPs (see the sketch after this list).
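To make the flatten/unflatten idea concrete, here is a generic padding-removal pattern: gather only the valid token positions into a compact tensor before the expensive per-token compute, then scatter the results back into the padded layout afterwards. This is a sketch of the general technique only, not the compute optimizer pass itself; the function names and shapes are assumptions. The analogous gather on label positions that are not `ignore_index` is one way to cut loss compute FLOPs.

```python
import torch


def flatten_valid_tokens(hidden: torch.Tensor, attention_mask: torch.Tensor):
    """Gather non-padding token vectors into a compact [num_valid, hidden_size] tensor.

    hidden:         [batch, seq_len, hidden_size]
    attention_mask: [batch, seq_len], 1 for valid tokens, 0 for padding
    Returns the packed tensor and the flat indices needed to unflatten later.
    """
    batch, seq_len, hidden_size = hidden.shape
    flat_mask = attention_mask.reshape(-1).bool()                  # [batch * seq_len]
    flat_indices = flat_mask.nonzero(as_tuple=False).squeeze(-1)   # [num_valid]
    packed = hidden.reshape(-1, hidden_size)[flat_indices]         # [num_valid, hidden_size]
    return packed, flat_indices


def unflatten_valid_tokens(packed: torch.Tensor, flat_indices: torch.Tensor,
                           batch: int, seq_len: int) -> torch.Tensor:
    """Scatter packed token vectors back to the padded [batch, seq_len, hidden_size] layout."""
    hidden_size = packed.shape[-1]
    out = packed.new_zeros(batch * seq_len, hidden_size)
    out[flat_indices] = packed
    return out.reshape(batch, seq_len, hidden_size)


# Round trip: run the expensive per-token compute on `packed` only, then restore the padded layout.
hidden = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
packed, idx = flatten_valid_tokens(hidden, mask)
restored = unflatten_valid_tokens(packed, idx, batch=2, seq_len=4)
```

With the 30 % embed density seen in the first example above, this kind of packing lets the per-token layers operate on roughly a third of the positions.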