Commits

kailums authored and GitHub committed 1a294609196
rope support 4D input tensor (#18454)

### Description
Change the RotaryEmbedding op implementation to add support for a 4D input tensor with shape [batch, num_heads, seq_len, head_size].

### Motivation and Context
The current RotaryEmbedding op only supports a 3D input tensor with shape [batch, seq_len, hidden_size]. For the llamav2 model, when FusionRotaryEmbeddings is used to fuse only the RotaryEmbeddings op, a transpose is applied to the query and key, so the input tensor of RotaryEmbeddings becomes 4D with shape [batch, num_heads, seq_len, head_size]. The current implementation cannot handle this case, so the op needs to support a 4D input tensor.
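To illustrate the shape handling this change enables, here is a minimal NumPy sketch of applying rotary position embedding to a 4D [batch, num_heads, seq_len, head_size] tensor. The function name `rotary_embedding_4d`, the non-interleaved (half-rotation) layout, and the default base are assumptions for illustration, not the actual ONNX Runtime kernel:

```python
# Minimal NumPy sketch: rotary position embedding over a 4D tensor.
# Illustrative only; not the ONNX Runtime RotaryEmbedding kernel.
import numpy as np

def rotary_embedding_4d(x, base=10000.0):
    """Apply rotary embedding to x of shape [batch, num_heads, seq_len, head_size]."""
    batch, num_heads, seq_len, head_size = x.shape
    half = head_size // 2

    # One frequency per rotated pair of channels: theta_i = base^(-2i/head_size).
    inv_freq = 1.0 / (base ** (np.arange(0, half) / half))
    # Angles per (position, frequency) pair: [seq_len, half].
    angles = np.outer(np.arange(seq_len), inv_freq)
    cos = np.cos(angles)  # [seq_len, half]
    sin = np.sin(angles)  # [seq_len, half]

    # Split channels into the two halves that get rotated against each other.
    x1, x2 = x[..., :half], x[..., half:]
    # cos/sin broadcast over the leading batch and head dimensions.
    return np.concatenate(
        [x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1
    )

# The 4D input arises when the fused graph transposes query/key from
# [batch, seq_len, num_heads, head_size] before the rotary op.
q = np.random.randn(2, 8, 16, 64).astype(np.float32)  # [b, h, s, d]
out = rotary_embedding_4d(q)
assert out.shape == q.shape
```

The broadcast over the leading batch and head axes is what makes the 4D case a small extension of the 3D one, where hidden_size = num_heads * head_size is split per head before the rotation is applied.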