Commits


pengwa authored and GitHub committed 7bec80d92a1
Fix reference count for autograd.Function (#15121) ### Fix reference count for autograd When PythonOp kernel initialized, `AddPointerScalarArgs` creates `const_args_` which put all non-tensor references (including ProcessGroup, string, or other user types) in it. In kernel's destructor, all ref cnt got decreased for `const_args_`. ``` void PythonOpBase::Clear() { for (auto ptr : const_args_) { auto obj = reinterpret_cast<PyObject*>(ptr); Py_DECREF(obj); } } ``` It means, we did not increase cnt, but just decrease cnt. Running the unit, segmentation fault will be thrown. The simple fix is to remove the Py_DECREF for those pointer-type constant inputs triggered by kernel destructor. NONTENSOR_OBJECT_POINTER_STORE is the place we increase the reference during export, then the reference will remain until the python program terminates. Additionally tunings: 1. Move some logs into verbose instead of warning in case of flooding training logs. 2. Move pointer type ref holding from python side (NONTENSOR_OBJECT_POINTER_STORE) to orttraining/orttraining/core/framework/torch/custom_function_register.h. Then we use a consistent approach to manage all PythonOp related python object/methonds ref count increasing and decreasing.