Commits


Prathik Rao authored and GitHub committed 11ad2994512
Adds ATen fallback for scaled_dot_product_attention (#21107) ### Description <!-- Describe your changes. --> Introduces an ATen fallback for `torch.nn.functional.scaled_dot_product_attention`. This operator was introduced in torch 2.0 and, since then, has had many updates including the implementation of memory efficient attention for V100 machines. The current torchscript exporter exports a subgraph for attention which does not provide the same memory savings that PyTorch's memory efficient attention kernel provides. Allowing fallback to PyTorch ATen op for attention helps mitigate memory spike issues for models leveraging memory efficient attention. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Memory issues arose when integrating ONNX Runtime Training with AML Stable Diffusion. --------- Co-authored-by: root <prathikrao@microsoft.com>