Commits


Wei-Sheng Chin authored and GitHub committed 9e8ad398479
Distributed Reduction (#18206) This PR implements distributed reduciton for llama 2. This version doesn't consider any cases requring re-sharding because we haven't seen any use cases. Intutive examples: - [supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[0]) -> [1,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] - [supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[1]) -> [2,1,6]-tensor with spec=RRS[0] and device_mesh=[0,1] - [not supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[2]) -> [2,4,1]-tensor with spec=RRS[0] and device_mesh=[0,1] Algorithm: When the reduced axes are not sharded, each device can call reduction directly. The output sharding spec will be identical to input sharding spec. We currently throw when input and output sharding specs are different. Review guideline: - Check 97b8d2f for new op's schema and how new op is registered. - Read tests in 2450f93 to get faimilar with the behavior of these ops. - Check the implementation details in 753d9af.