Commits


Thiago Crepaldi authored and GitHub committed 9817b8c8a76
Fix state_dict/checkpoint issue introduced by #4639 (#4984)

https://github.com/microsoft/onnxruntime/pull/4639 changed the default behavior by removing the optimizer state from the state_dict/checkpoint APIs. The reason for that change was to allow models trained on ORT to be used for inference in PyTorch, which is an important feature. However, as a result of that change, resuming training from a checkpoint left the optimizer starting with random state, leading to poor performance. It also caused reproducibility issues, since the optimizer couldn't resume from its previous state. This PR adds a boolean flag to the state_dict/save_checkpoint APIs that, when True (the default), saves both the model and the optimizer state; when False, only the model state is kept.
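
Below is a minimal, self-contained sketch of the behavior the flag controls, written with plain PyTorch objects for illustration. The function name `state_dict`, the flag name `include_optimizer_state`, and the checkpoint layout are assumptions for this sketch, not necessarily the names used in the actual ORT API added by #4984.

```python
# Illustrative sketch only: names and checkpoint layout are assumptions,
# not the exact ORT API introduced by this PR.
import torch

def state_dict(model, optimizer, include_optimizer_state=True):
    """Return the model state, optionally bundled with the optimizer state."""
    state = {"model": model.state_dict()}
    if include_optimizer_state:
        # Keeping the optimizer state lets a resumed run continue with the
        # same momentum/step bookkeeping instead of a freshly initialized optimizer.
        state["optimizer"] = optimizer.state_dict()
    return state

# Usage sketch
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

full_ckpt = state_dict(model, optimizer)                                   # model + optimizer (default)
model_only = state_dict(model, optimizer, include_optimizer_state=False)   # model weights only, e.g. for inference
torch.save(full_ckpt, "checkpoint.pt")
```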