SLATE: Fix codegen for elimination of 1x1 block #5205
  idx = ctx.gem_to_pym_multiindex(expr.multiindex)
  var, sub_idx = ctx.pymbolic_variable_and_destruct(expr)
- lhs = p.Subscript(var, idx + sub_idx)
+ lhs = p.Subscript(var, sub_idx + idx)
Codex CLI:
The Slate kernel contains the invalid assignment t1[i_3, i_1, i_2] = T1[i_1, i_3] * t0[i_3, i_2], with t1 shaped (4, 1, 1). That narrows it to the GEM/Impero lowering of (4x1) * (1x1) rather than wrapper argument merging.
The mismatch is concrete now. For a ComponentTensor temporary with free indices, Loopy declares the shape as (free-index extents) + (tensor shape) and later references it that way, but the evaluation assignment writes (tensor component indices) + (free indices). This only becomes a bounds error when the extents differ, as in (4, 1, 1).
The code fix is effective and now produces the corrected subscript t1[i_1, i_2, i_3].
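The ordering mismatch can be reproduced outside of TSFC with a minimal plain-Python sketch. Nothing here is TSFC or Loopy code: `zeros`, `fill` and the operand values are hypothetical stand-ins, and the loop extents are taken from the (4x1) * (1x1) case described above.

```python
def zeros(shape):
    """Nested lists of zeros with the given shape (stand-in for a temporary)."""
    if not shape:
        return 0.0
    return [zeros(shape[1:]) for _ in range(shape[0])]


T1 = [[1.0], [2.0], [3.0], [4.0]]  # 4x1 operand (illustrative values)
t0 = [[2.0]]                        # 1x1 operand (illustrative value)

free_extents = (4, 1)  # extents of the free indices i_1, i_2
tensor_shape = (1,)    # extent of the ComponentTensor index i_3
declared_shape = free_extents + tensor_shape  # (4, 1, 1)


def fill(free_indices_first):
    """Fill the temporary using one of the two subscript orderings."""
    t1 = zeros(declared_shape)
    for i_1 in range(4):
        for i_2 in range(1):
            for i_3 in range(1):
                value = T1[i_1][i_3] * t0[i_3][i_2]
                if free_indices_first:
                    t1[i_1][i_2][i_3] = value  # matches the declared shape
                else:
                    t1[i_3][i_1][i_2] = value  # tensor index first
    return t1


fixed = fill(free_indices_first=True)

try:
    fill(free_indices_first=False)
    old_ordering_failed = False
except IndexError:
    # i_1 runs up to 3 but lands in an axis of extent 1
    old_ordering_failed = True
```

If all three extents were equal, both orderings would stay in bounds, which is consistent with larger blocks not raising an error.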
Human: Where is loopy declaring the free-index before the tensor shape?
Codex:
The key line is tsfc/loopy.py:242:
shape = tuple([i.extent for i in ctx.indices[temp]]) + temp.shape
That is where TSFC’s Loopy backend builds the loopy.TemporaryVariable shape as:
(free-index extents) + (GEM tensor shape)
The next line, tsfc/loopy.py:243, passes that shape into lp.TemporaryVariable(...).
The ctx.indices mapping comes from impero_c.indices at tsfc/loopy.py:229, and those entries are computed in GEM's declaration placement at /scratch/brubeckmarti/firedrake-venv/src/fiat/gem/impero_utils.py:243:
indices[e] = get_indices(e)[len(loop_indices):]
References to those temporaries also use ctx.indices[node]; see tsfc/loopy.py:174, especially line 177.
I'm concerned that there is a rank-3 temporary. We are only doing matrix-matrix multiplications, so I am only expecting temporaries up to rank 2.
What about indices for vector/tensor shape?
> What about indices for vector/tensor shape?
We expect those when computing the matrix entries. We do not see them in the matrix-matrix product when the contracted extent is larger than 1; we only get the rank-3 temporary when the contracted extent is equal to 1.
The rank-3 tensor arising from matrix-matrix multiplication is this one
firedrake/firedrake/slate/slac/utils.py
Lines 137 to 139 in fda580f
We have special handling for IndexSum over an index with extent=1
https://github.com/firedrakeproject/fiat/blob/2d80c42e4720a376d8bcee89ff0eb210100f39c3/gem/gem.py#L903-L910
That code is introducing the extra rank, which triggers the issue that this PR is fixing.
We could add branching to _slate2gem_mul to prevent the extra rank:
diff --git a/firedrake/slate/slac/utils.py b/firedrake/slate/slac/utils.py
index bd5c81e23..57972ed73 100644
--- a/firedrake/slate/slac/utils.py
+++ b/firedrake/slate/slac/utils.py
@@ -134,9 +134,14 @@ def _slate2gem_mul(expr, self):
     A, B = map(self, expr.children)
     *i, k = tuple(make_indices(len(A.shape)))
     _, *j = tuple(make_indices(len(B.shape)))
-    ABikj = Product(Indexed(A, tuple(i + [k])),
-                    Indexed(B, tuple([k] + j)))
-    return ComponentTensor(IndexSum(ABikj, (k, )), tuple(i + j))
+    if A.shape[-1] == 1:
+        ABij = Product(Indexed(A, tuple(i + [0])),
+                       Indexed(B, tuple([0] + j)))
+        return ComponentTensor(ABij, tuple(i + j))
+    else:
+        ABikj = Product(Indexed(A, tuple(i + [k])),
+                        Indexed(B, tuple([k] + j)))
+        return ComponentTensor(IndexSum(ABikj, (k, )), tuple(i + j))
But I don't think we should be worried about an extra rank of dimension 1 if we know how to deal with it properly. So I think that we should leave the PR as it is.
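The branch in the diff above relies on a sum over an extent-1 index being equal to the summand evaluated at index 0. A minimal plain-Python sketch of that equivalence follows; it uses nested lists rather than GEM nodes, and `matmul_index_sum` and `matmul_branching` are hypothetical names introduced only for illustration.

```python
def matmul_index_sum(A, B):
    """C[i][j] = sum_k A[i][k] * B[k][j], the general contraction."""
    n, m, p = len(A), len(B), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(m)) for j in range(p)]
            for i in range(n)]


def matmul_branching(A, B):
    """Skip the sum when the contracted extent is 1, otherwise contract."""
    if len(A[0]) == 1:
        # Contracted extent is 1: index both operands at 0, no sum needed
        return [[A[i][0] * B[0][j] for j in range(len(B[0]))]
                for i in range(len(A))]
    return matmul_index_sum(A, B)


A = [[1.0], [2.0], [3.0], [4.0]]  # 4x1 (illustrative values)
B = [[5.0]]                        # 1x1 (illustrative value)

singleton_result = matmul_branching(A, B)
general_result = matmul_index_sum(A, B)
```

Both paths give the same 4x1 result, so the branch changes only the intermediate representation and not the values.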
Description
Fix a codegen index mismatch error when using SLATE for elimination of a 1x1 block. Elimination of larger blocks was not previously affected.
ComponentTensor in general (not only for slate expressions)
Block(T, (i, j)) as a shorthand for Block(T, ((i, ), (j, )))