Commits


Tianlei Wu authored and GitHub committed a6c5e2cd20d
[CUDA] FusedMHARunnerFP16v2 thread-safe (#21420) ### Description - [x] Rewrite FusedMHARunnerFP16v2 to make it thread-safe. - [x] Add multi-threading tests Previously, the kernel parameters params is stored as a member of mha runner, which means that different threads might change the params at the same time and impacts the other threads. For example, if batch_size and seq_len was changed by another thread to larger values in setup(...), buffer overrun might happen in run(...) because a kernel could read/write memory out of range of allocated buffers. In new implementation, I change the api and remove mutable member variables to make it thread safe. Below is summary of change: Before: ``` class FusedMHARunnerFP16v2::mhaImpl { void setup(int seq_len, int batch_size) { // change scalar params } void run(input, output) { // change params for input and output pointers // launch kernel using params } Fused_multihead_attention_params_v2 params; // mutable, not thread-safe } ``` After: ``` class FusedMHARunnerFP16v2::FmhaImpl { void setup(int seq_len, int batch_size, Fused_multihead_attention_params_v2& params) { // change params } void run(params, input, output) { // change params with input and output pointers // launch kernel using params } } ``` ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/18854 https://github.com/microsoft/onnxruntime/issues/21413