Commit 2a17d5cf329, authored by kunal-vaishnavi and committed via GitHub
LLaMA Model Optimization (#18021)

### Description

This PR contains fusion-level and kernel-level optimizations for [Meta's LLaMA-2](https://blogs.microsoft.com/blog/2023/07/18/microsoft-and-meta-expand-their-ai-partnership-with-llama-2-on-azure-and-windows/).

Some of the added optimizations include:

- SimplifiedLayerNorm changes
  - Fusions for multiple variants
- SkipSimplifiedLayerNorm changes
  - Kernel support for CPU
- Rotary embeddings (previously unavailable)
  - Fusions for multiple variants
  - CPU and CUDA kernels
  - Support for interleaved and non-interleaved rotation in the same kernels
  - Optimized cache that requires half of its originally exported size
    - Reduced from `(max_sequence_length, head_size)` to `(max_sequence_length, head_size / 2)` (see the first sketch after this list)
- Multi-head attention
  - Support for 2D and 3D attention masks
- Group query attention (for FP16 CUDA and INT4 CUDA)
  - Integration with flash attention v2 and past-present buffer sharing
  - Removes the need for the `attention_mask` input, since masking is handled inside the kernel (see the second sketch after this list)
- 4-bit quantization
  - The `block_size` parameter is available for customization (see the third sketch after this list)
- Support for the new changes in the [Microsoft version](https://github.com/microsoft/Llama-2-Onnx)
- Support for combinations of the variants below (e.g., export the ORT version and run it with Optimum)
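The half-size rotary cache works because rotary embeddings rotate pairs of dimensions by a shared angle, so only `head_size / 2` unique cos/sin values are needed per position. Below is a minimal NumPy sketch of that idea, covering both the interleaved and non-interleaved pairing conventions; it illustrates the math only and is not the ORT kernel implementation, and all names are illustrative.

```python
import numpy as np

def build_rotary_cache(max_sequence_length: int, head_size: int, base: float = 10000.0):
    # Only head_size / 2 unique frequencies are needed because each angle
    # rotates a pair of dimensions, halving the exported cache size:
    # (max_sequence_length, head_size) -> (max_sequence_length, head_size / 2).
    inv_freq = 1.0 / (base ** (np.arange(0, head_size, 2) / head_size))
    angles = np.outer(np.arange(max_sequence_length), inv_freq)
    return np.cos(angles), np.sin(angles)  # each (max_sequence_length, head_size / 2)

def apply_rotary(x, cos_cache, sin_cache, interleaved=False):
    # x: one attention head's activations, shape (seq_len, head_size).
    # interleaved=True pairs (x[0], x[1]), (x[2], x[3]), ...;
    # interleaved=False pairs x[:, :head_size/2] with x[:, head_size/2:].
    seq_len, head_size = x.shape
    cos, sin = cos_cache[:seq_len], sin_cache[:seq_len]
    if interleaved:
        x1, x2 = x[:, 0::2], x[:, 1::2]
    else:
        x1, x2 = x[:, : head_size // 2], x[:, head_size // 2 :]
    r1 = x1 * cos - x2 * sin
    r2 = x1 * sin + x2 * cos
    out = np.empty_like(x)
    if interleaved:
        out[:, 0::2], out[:, 1::2] = r1, r2
    else:
        out[:, : head_size // 2], out[:, head_size // 2 :] = r1, r2
    return out
```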
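Past-present buffer sharing is the reason the `attention_mask` input can be dropped: past and present key/value tensors share one buffer pre-allocated to `max_sequence_length`, and a running sequence length tells the kernel how many positions are valid. The following is a hedged Python sketch of that idea under assumed shapes; the class and method names are hypothetical and do not reflect the ORT kernel's API.

```python
import numpy as np

class SharedKVCache:
    """Hypothetical sketch: past and present K/V share one pre-allocated buffer."""

    def __init__(self, batch, num_kv_heads, max_sequence_length, head_size):
        # For group query attention, num_kv_heads is smaller than the number
        # of query heads, which shrinks this buffer further.
        shape = (batch, num_kv_heads, max_sequence_length, head_size)
        self.keys = np.zeros(shape, dtype=np.float16)
        self.values = np.zeros(shape, dtype=np.float16)
        self.seq_len = 0  # valid positions so far; replaces an explicit attention_mask

    def append(self, new_keys, new_values):
        # Write the new tokens' K/V in place instead of concatenating past and
        # present tensors, which would reallocate on every decoding step.
        n = new_keys.shape[2]
        self.keys[:, :, self.seq_len : self.seq_len + n] = new_keys
        self.values[:, :, self.seq_len : self.seq_len + n] = new_values
        self.seq_len += n
        # Attention reads only the first seq_len positions, so no mask input is needed.
        return self.keys[:, :, : self.seq_len], self.values[:, :, : self.seq_len]
```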
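For the 4-bit path, `block_size` controls how many consecutive weights share one scale: smaller blocks track local weight ranges more closely, larger blocks compress better. Below is a minimal sketch of symmetric block-wise 4-bit quantization that illustrates the role of `block_size`; it is not the exact packing format ORT uses.

```python
import numpy as np

def quantize_blockwise_int4(weights, block_size=32):
    # Pad to a whole number of blocks, then give each block one FP32 scale.
    padded = np.pad(weights.astype(np.float32), (0, -len(weights) % block_size))
    blocks = padded.reshape(-1, block_size)
    scales = np.abs(blocks).max(axis=1, keepdims=True) / 7.0
    scales[scales == 0] = 1.0  # guard all-zero blocks against division by zero
    q = np.clip(np.round(blocks / scales), -8, 7).astype(np.int8)
    return q, scales

def dequantize_blockwise_int4(q, scales, original_len):
    return (q.astype(np.float32) * scales).reshape(-1)[:original_len]

# Example: smaller blocks usually reconstruct the weights more accurately.
w = np.random.randn(1000).astype(np.float32)
for bs in (16, 32, 128):
    q, s = quantize_blockwise_int4(w, bs)
    err = np.abs(dequantize_blockwise_int4(q, s, len(w)) - w).max()
    print(f"block_size={bs:4d}  max abs error={err:.4f}")
```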
Supported variants of LLaMA-2 include:

- [ORT version](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/llama)
  - Produces one ONNX file that is already optimized (and quantized if requested)
  - Integrates with Optimum (a sketch appears at the end of this description)
- [Another Microsoft version](https://github.com/microsoft/Llama-2-Onnx)
  - Already exported and available off-the-shelf
  - Faster versions of those models will be uploaded there soon
- [Hugging Face version](https://huggingface.co/meta-llama)
  - Models whose names end with `-hf`
  - Some older and current versions of [`transformers`](https://github.com/huggingface/transformers) and [`optimum`](https://github.com/huggingface/optimum) export the model to ONNX differently
  - While some older versions are supported, it is recommended to use the latest package versions

### Usage

To use the optimizations, please see `README.md` for details, and note the various `requirements.txt` files for the package versions recommended for these changes.

To run the ORT transformer optimizer separately, run the script as follows:

```
$ cd onnxruntime/onnxruntime/python/tools/transformers/
$ python3 optimizer.py --input <filename>.onnx --output <filename>.onnx --model_type gpt2 --num_heads <number of attention heads> --hidden_size <attention hidden size> --use_external_data_format --opt_level 0
```

### Motivation and Context

This PR helps address the following issues:

- https://github.com/microsoft/onnxruntime/issues/14997
- https://github.com/microsoft/onnxruntime/issues/16254
- https://github.com/microsoft/onnxruntime/issues/17681
- https://github.com/microsoft/onnxruntime/issues/17925
- https://github.com/microsoft/onnxruntime-inference-examples/issues/320

This PR uses changes from the following PRs:

- https://github.com/pytorch/pytorch/pull/104468
- https://github.com/pytorch/pytorch/pull/109759
- https://github.com/microsoft/onnxruntime/pull/17020
- https://github.com/microsoft/onnxruntime/pull/17674
- https://github.com/microsoft/onnxruntime/pull/17890
- https://github.com/microsoft/onnxruntime/pull/17920
- https://github.com/huggingface/transformers/pull/26162
- https://github.com/huggingface/optimum/pull/1257
- https://github.com/huggingface/optimum/pull/1289
- https://github.com/huggingface/optimum/pull/1462

### New TorchDynamo Exporter (experimental stage)

This PR uses changes from the following issues and PRs to begin supporting the [new TorchDynamo exporter](https://pytorch.org/docs/stable/onnx.html#torchdynamo-based-onnx-exporter):

- https://github.com/huggingface/transformers/pull/26307
- https://github.com/pytorch/pytorch/issues/104903
- https://github.com/pytorch/pytorch/pull/105040
- https://github.com/microsoft/onnxscript/pull/847
- https://github.com/microsoft/onnxscript/pull/862
- https://github.com/microsoft/onnxscript/issues/493
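As a closing illustration of the Optimum integration path mentioned under the supported variants, here is a hedged sketch of loading an exported LLaMA-2 ONNX model with Optimum and generating text. The model directory and prompt are placeholders, and the exact arguments may vary across `optimum` versions.

```python
# All paths and IDs below are placeholders; adjust them to your exported model.
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = ORTModelForCausalLM.from_pretrained("./llama2-7b-onnx")  # directory containing the exported ONNX file

inputs = tokenizer("ONNX Runtime is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```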