Commits


pengwa authored and GitHub committed 1150b1f81ea
ORTModule memory improvement (#18924) ## Dependency https://github.com/microsoft/onnxruntime/pull/19007 ## ORTModule memory efficient gradient management Previously I have tried to solve the coarsed-grained gradient accumulation/update problem in ORTModule with https://github.com/microsoft/onnxruntime/pull/8979, while that resolution somehow is not fully validated with DDP or there is user hooks on the gradient accumulation on torch parameter. This PR is addressing the problem in the similar approach as PR 8979, e.g. trigger gradient accumulation once ORT computed the grad, but instead of use a AccumulateGrad op, this time with a ONNX operator PythonOp, internally it will call param.backward(grad), which will help handle all related hooks correctly. ## Design Check the details from https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ ## Convergence Validation:  differences are on mostly 0.000x, sometimes 0.00x, which may comes from the different order gradient apply happens before or after this change (on deepspeed zero stage 2) ## TODO Consolidate the logic with Stage3's similar logic.