Commits


pengwa authored and GitHub committed 1d322855363
Statistics tool for ORTModule convergence parity (#15020) ### Statistics tool for ORTModule convergence parity As ORTModule get more and more validated, it is pretty fast to intergrade PyTorch based model with ORT. The same time, we need make sure once there is convergence issue, we don't spend months of time to investigate. As part of this efforts, this PR is introducing a tool to dump activation statistics without much involvement from users. The dumping results contains only some statistic numbers plus sampled data, which is not big, compared with dumping all the tensors, it is much faster and space efficient. For us to use it, two single lines are needed before wrapping ORTModule. For baseline run, need also apply the same trick. ``` + from onnxruntime.training.utils.hooks import SubscriberManager, StatisticsSubscriber + SubscriberManager.subscribe(model, [StatisticsSubscriber("pt_out", override_output_dir=True)]) ``` Once you run the steps, following command can be used to merge result into per-step-summary respectively for ORT and baseline runs. ```bash python -m onnxruntime.training.utils.hooks.merge_activation_summary --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output ``` Docs is added here as part of this PR [convergence investigation notes](https://github.com/microsoft/onnxruntime/blob/pengwa/conv_tool/docs/ORTModule_Convergence_Notes.md) Based on the generated merged files, we can compare them with tools.  ### Design and Implementation This PR introduced a common mechanism registering custom logic for nn.Module's post forward hooks. And statistics for activation (StatisticsSubscriber) is one of the implementations. If there is other needs, we can define another XXSubscriber to do the customized things.