Commits


Ivan Komarov authored and GitHub committed 16b39e5b87b
`symbolic_shape_infer.py`: Fix slicing a tensor that has a sympy.Min() in its shape (#14384) ### Description `_infer_Slice()` is a function (arguably the most complex one) in `symbolic_shape_infer.py` that infers the shape of the output of a `Slice` node. This commit fixes an edge case in `_infer_Slice()` caused by a SymPy quirk. When both the end of the slice (let's call it `e`) and the corresponding dimension of the sliced tensor (let's call it `dim`) are arbitrary symbolic expressions, `symbolic_shape_infer.py` [checks](https://github.com/microsoft/onnxruntime/blob/de7a868d5f3390d7c095a53c26abd39f402f3f93/onnxruntime/python/tools/symbolic_shape_infer.py#L1728) if `e <= dim`. Comparing symbolic expressions is hard in general, so if the comparison fails, `symbolic_shape_infer.py` [gives up](https://github.com/microsoft/onnxruntime/blob/de7a868d5f3390d7c095a53c26abd39f402f3f93/onnxruntime/python/tools/symbolic_shape_infer.py#L1734) and assumes that `e` is equal to `dim`. A failure of this sort currently happens for expressions of the form `Y - X >= 0` where `Y` contains a `sympy.Min()` (`symbolic_shape_infer.py` tries to rewrite `X <= Y` comparisons in various ways, and `Y - X >= 0` is [one of them](https://github.com/microsoft/onnxruntime/blob/de7a868d5f3390d7c095a53c26abd39f402f3f93/onnxruntime/python/tools/symbolic_shape_infer.py#L1664)). An simple example to illustrate this: ```python >>> import sympy >>> X = sympy.Symbol('X', positive=True, integer=True) >>> >>> y1 = 9999 >>> Y1 = X + y1 - 5000 >>> bool(Y1 - X >= 0) True >>> >>> y2 = X + 4999 >>> Y2 = X + y2 - 5000 >>> bool(Y2 - X >= 0) True >>> >>> y3 = sympy.Min(y1, y2) >>> Y3 = X + y3 - 5000 >>> bool(Y3 - X >= 0) Traceback (most recent call last): File "<stdin>", line 1, in <module> File ".../venv/lib/python3.9/site-packages/sympy/core/relational.py", line 511, in __bool__ raise TypeError("cannot determine truth value of Relational") TypeError: cannot determine truth value of Relational ``` If you assume that `X` is positive symbol (`symbolic_shape` [does assume](https://github.com/microsoft/onnxruntime/blob/de7a868d5f3390d7c095a53c26abd39f402f3f93/onnxruntime/python/tools/symbolic_shape_infer.py#L2129) this for graph inputs), then both `Y1 >= X` and `Y2 >= X` holds, and SymPy can prove this. This means that `Y3 >= X` also holds (since `Y3` is essentially equal to either `Y1` or `Y2`, depending on the value of `X`), but this is too hard for SymPy to prove. I confirmed that this is still the case for the latest SymPy version (`1.11.1`). This commit tries to fix this edge case by slightly rewriting the expression containing `sympy.Min()`. I explain the details in the comments in `symbolic_shape_infer.py`, so I won't duplicate them in the PR description. ### Motivation and Context This sounds like a very contrived example, but it actually appeared in the wild when we tried to infer shapes for an ONNX graph exported from PyTorch that used relative-position multihead attention from Fairseq. The problematic line is [here](https://github.com/facebookresearch/fairseq/blob/7d050ada7d365b535bf7c634ed3bcaf1cc2930b1/fairseq/modules/espnet_multihead_attention.py#L192). In our codebase, we have something like `matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]` before we add `matrix_ac` and `matrix_bd`. `matrix_bd` is itself a result of another slice, hence its shape contains `sympy.Min()`, and the SymPy weirdness described above prevents `symbolic_shape_infer.py` from correctly inferring the final shape of `matrix_bd`. Then `symbolic_shape_infer.py` explodes when we try to add `matrix_ac` and `matrix_bd`, because their shapes are not compatible. I added a small self-contained unit test to illustrate the problem. *Without* the fix, `slice_out_cropped` has shape `[N + Min(42, N + 21) - 22]`, and `input` has shape `[N]`, and we get this: ``` > python onnxruntime_test_python_symbolic_shape_infer.py ..................Cannot determine if 22 - N < 0 Unable to determine if N <= N + Min(42, N + 21) - 22, treat as equal E.... ====================================================================== ERROR: test_slice_of_min (__main__.TestSymbolicShapeInferenceForSlice) ---------------------------------------------------------------------- Traceback (most recent call last): File "/home/dfyz/onnxruntime/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py", line 460, in test_slice_of_min model = SymbolicShapeInference.infer_shapes(onnx.helper.make_model(graph_def)) File "/home/dfyz/onnxruntime/onnxruntime/test/python/../../python/tools/symbolic_shape_infer.py", line 2461, in infer_shapes raise Exception("Incomplete symbolic shape inference") Exception: Incomplete symbolic shape inference ---------------------------------------------------------------------- Ran 23 tests in 0.486s FAILED (errors=1) ``` *With* the fix, both tensors have shape `[N]`, and the test passes. --------- Co-authored-by: Ivan Komarov <dfyz@yandex-team.ru>