Commits


Jambay Kinley authored and GitHub committed 1af06815540
Bfloat16 support for MatMulBnb4, Training support bitsandbytes>=0.41.2 (#18484)

### Description
Add bfloat16 support for the `MatMulBnb4` contrib op. This is useful for QLoRA fine-tuning.

- On GPUs with SM80+ (A100, etc.), it uses the native CUDA bfloat16 dtype, `nv_bfloat16`. On other GPUs, it uses the onnxruntime `BFloat16` type, which uses float for compute.
- I have validated the op in a llama2-7b training scenario: the losses match PyTorch training and the training throughput is better.
- A bfloat16 case cannot be added to the op unit test because casting `BFloat16` to and from float multiple times during the test makes the required tolerances unachievable.

The custom autograd function exporter in onnxruntime-training is updated to support the latest version of bitsandbytes, which changed how the `quant_state` is stored.

### Motivation and Context
Enable QLoRA fine-tuning with bfloat16.
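For context on the `quant_state` change, here is a minimal sketch of the kind of version-tolerant handling the exporter needs: older bitsandbytes releases stored `quant_state` as a plain list/tuple of fields, while the newer releases targeted here wrap the same information in an object with named attributes. The helper name `unpack_quant_state`, the `UnpackedQuantState` container, the attribute names, and the old field order below are assumptions for illustration only; consult the bitsandbytes source for the exact layout of each release.

```python
# Hypothetical helper illustrating the quant_state layout change in bitsandbytes.
# Field/attribute names are assumptions for illustration, not the exporter's code.
from typing import Any, NamedTuple


class UnpackedQuantState(NamedTuple):
    absmax: Any
    shape: Any
    dtype: Any
    blocksize: int
    quant_type: str


def unpack_quant_state(quant_state: Any) -> UnpackedQuantState:
    """Normalize quant_state across bitsandbytes versions.

    Older releases stored quant_state as a plain list/tuple of fields;
    newer releases wrap the same information in an object with named
    attributes (attribute names assumed here).
    """
    if hasattr(quant_state, "absmax"):
        # New-style quant_state object with named attributes.
        return UnpackedQuantState(
            absmax=quant_state.absmax,
            shape=quant_state.shape,
            dtype=quant_state.dtype,
            blocksize=quant_state.blocksize,
            quant_type=quant_state.quant_type,
        )
    # Old-style list/tuple layout (field order assumed).
    absmax, shape, dtype, blocksize, _compressed_stats, quant_type, *_ = quant_state
    return UnpackedQuantState(absmax, shape, dtype, blocksize, quant_type)
```

A normalization step like this lets the exporter consume either representation without branching on the installed bitsandbytes version elsewhere in the code.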