Commits


guyang3532 authored and GitHub committed 341484e67c8
Embedding sparsity optimization (#16141) ### Description Optimize compute graph by eliminating padding in embedding. ### Motivation and Context The computation for padding in nodes after embedding is unnecessary and waste computation resources. This pr just add an Optimizer of PaddingElimination to check and eliminate the padding after embedding automatically by modifying the graph. ### Implementation: 1. Find and check embedding node in graph. 2. Iterate the subgraph afterward the embedding node and record all the input nodes and output nodes to this subgraph. 3. Insert 'Reshape + ShrunkenGather' to flatten each input node shape from [batch_size, seqlen, ...] to [valid_token_without_padding, ...], and insert 'GatherGrad + Reshape' to unflatten each output node shape from [valid_token_without_padding, ...] to [batch_size, seqlen, ...] --------- Co-authored-by: mindest <linminuser@gmail.com>