Commits


Wei-Sheng Chin authored and GitHub committed 24f9c1afe3a
Distributed Expand (#18126)

This PR implements DistributedExpand for llama 2.

Representative examples of DistributedExpand:

- [shard on a non-expanded axis] `input tensor (shape=[8, 1], spec=S[0]R, device_mesh=[0,1]) -> Expand(target_shape=[8, 2]) -> output tensor (shape=[8, 2], spec=S[0]R, device_mesh=[0,1])`
- [sharding the expanded axis is invalid, since an expanded axis must have dim=1 and an axis with dim=1 cannot be sharded] `input tensor (shape=[1, 8], spec=S[0]R, device_mesh=[0,1]) -> Expand(target_shape=[2, 8]) -> output tensor (shape=[2, 8], spec=S[0]R, device_mesh=[0,1])`

From those examples, we observe a few important behaviors:

- The output sharding spec is always the same as the input sharding spec.
- Expanding always happens on an axis with dimension 1; otherwise, it would violate the broadcasting rule.
- No communication is needed, since all computation can happen locally.

Let's consider the first example again. If you put the first half of the tensor (shape: [4, 1]) on device 0 and the second half (shape: [4, 1]) on device 1, then `Expand` each with local target shape [4, 2], the two local tensors (shape: [4, 2]) are exactly the shards described by the output sharding spec.

Algorithm:

- Compute the logical (i.e., unsharded) shapes of input and output.
- Compute the sharded output shape from the logical output shape.
- Call Expand to broadcast the local input to the sharded output shape.

See the sketch at the end of this description for these steps in code.

How to review?

- Start with the [changes in onnxruntime_test_distributed.py](https://github.com/microsoft/onnxruntime/pull/18126/commits/ea33392f375afd8e95d29bd5b1a403192ed3bebc). Those tests are good examples of using this op.
- [Read expand.h/expand.cc](https://github.com/microsoft/onnxruntime/pull/18126/commits/e4c49987f5a09e19527248adcc197b7d4a695636). Those changes expose functionality in Expand to DistributedExpand.
- Read distributed_expand.h/distributed_expand.cc. It follows the algorithm described above. The commit https://github.com/microsoft/onnxruntime/pull/18126/commits/68ac301bbaff44d08168ac9049161a4d428b3c3d first sketches the definition of DistributedExpand, and the next commit https://github.com/microsoft/onnxruntime/pull/18126/commits/0eb9330c3ba836911932444caca7fec0cbdad222 adds the real implementation.
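
Below is a minimal, self-contained sketch of the three-step algorithm above, using NumPy broadcasting in place of the ONNX Expand kernel. The helper names (`sharded_shape`, `distributed_expand`) and the spec encoding (a single sharded axis over a 1-D device mesh) are illustrative assumptions for this sketch, not the actual ONNX Runtime API.

```python
import numpy as np

def sharded_shape(logical_shape, sharded_axis, num_devices):
    # Shape of one local shard, given the logical (unsharded) shape.
    # sharded_axis=None would mean fully replicated (spec RR).
    shape = list(logical_shape)
    if sharded_axis is not None:
        assert shape[sharded_axis] % num_devices == 0, "axis must split evenly"
        shape[sharded_axis] //= num_devices
    return shape

def distributed_expand(local_input, target_shape, sharded_axis, num_devices):
    # Step 1: the logical shapes are known (target_shape is the logical output).
    # Step 2: compute the sharded output shape from the logical output;
    #         the output spec equals the input spec, so the same axis is sharded.
    local_out_shape = sharded_shape(target_shape, sharded_axis, num_devices)
    # Step 3: Expand (broadcast) the local input to the sharded output shape.
    return np.broadcast_to(local_input, local_out_shape)

# First example: shape=[8, 1], spec=S[0]R, device_mesh=[0, 1], target [8, 2].
full = np.arange(8).reshape(8, 1)
local_shards = np.split(full, 2, axis=0)  # each device holds a [4, 1] shard
local_outputs = [distributed_expand(s, [8, 2], sharded_axis=0, num_devices=2)
                 for s in local_shards]
# No communication happened, yet concatenating the per-device outputs
# reproduces Expand applied to the full (logical) tensor.
assert np.array_equal(np.concatenate(local_outputs, axis=0),
                      np.broadcast_to(full, (8, 2)))
```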