Skip to content

SLATE: Fix codegen for elimination of 1x1 block#5205

Open
pbrubeck wants to merge 1 commit into
releasefrom
pbrubeck/fix/slate-1x1
Open

SLATE: Fix codegen for elimination of 1x1 block#5205
pbrubeck wants to merge 1 commit into
releasefrom
pbrubeck/fix/slate-1x1

Conversation

@pbrubeck

@pbrubeck pbrubeck commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

Description

Fix a codegen index mismatch error when using SLATE for elimination of a 1x1 block, elimination of larger blocks was not previously affected.

  • The fix has effect on ComponentTensor in general (not only for slate expressions)
  • Improve SLATE API to allow Block(T, (i, j)) as a shorthand for Block(T, ((i, ), (j, )))

Comment thread tsfc/loopy.py
idx = ctx.gem_to_pym_multiindex(expr.multiindex)
var, sub_idx = ctx.pymbolic_variable_and_destruct(expr)
lhs = p.Subscript(var, idx + sub_idx)
lhs = p.Subscript(var, sub_idx + idx)

@pbrubeck pbrubeck Jun 25, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Codex CLI:

The Slate kernel contains the invalid assignment: t1[i_3, i_1, i_2] = T1[i_1, i_3] * t0[i_3, i_2] with t1 shaped (4,1,1). That
narrows it to GEM/Impero lowering of (4x1) * (1x1) rather than wrapper argument merging.

The mismatch is concrete now. For a ComponentTensor temporary with free indices, Loopy declares shape as (free-index extents) + tensor-shape, and
later references it that way; but the evaluation assignment writes (tensor component indices) + (free indices). It only becomes a bounds error when the extents differ, like (4, 1, 1).

The code fix is effective and now produces the corrected subscript t1[i_1, i_2, i_3].

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Human: Where is loopy declaring the free-index before the tensor shape?

Codex:

The key line is tsfc/loopy.py:242:

shape = tuple([i.extent for i in ctx.indices[temp]]) + temp.shape

That is where TSFC’s Loopy backend builds the loopy.TemporaryVariable shape as:

(free-index extents) + (GEM tensor shape)

The next line, tsfc/loopy.py:243, passes that shape into lp.TemporaryVariable(...).

The ctx.indices mapping comes from impero_c.indices at tsfc/loopy.py:229, and those entries are computed in GEM’s declaration placement at /
scratch/brubeckmarti/firedrake-venv/src/fiat/gem/impero_utils.py:243:

indices[e] = get_indices(e)[len(loop_indices):]

References to those temporaries also use ctx.indices[node]; see tsfc/loopy.py:174, especially line 177.

@pbrubeck pbrubeck Jun 25, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm concerned by the fact that there is a rank-3 temporary. We are only doing matrix-matrix multiplications, so I am only expecting temporaries up to rank-2

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about indices for vector/tensor shape?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about indices for vector/tensor shape?

We expect those when computing the matrix entries, we do not see them in the matrix-matrix product for contracted extent larger than 1, we only get the rank-3 temporary when the contracted extent is equal to 1.

The rank-3 tensor arising from matrix-matrix multiplication is this one

ABikj = Product(Indexed(A, tuple(i + [k])),
Indexed(B, tuple([k] + j)))
return ComponentTensor(IndexSum(ABikj, (k, )), tuple(i + j))

We have special handling for IndexSum over an index with extent=1
https://github.com/firedrakeproject/fiat/blob/2d80c42e4720a376d8bcee89ff0eb210100f39c3/gem/gem.py#L903-L910

That code is introducing the extra rank, which triggers the issue that this PR is fixing.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could add branching to _slate2gem_mul to prevent the extra rank

diff --git a/firedrake/slate/slac/utils.py b/firedrake/slate/slac/utils.py
index bd5c81e23..57972ed73 100644
--- a/firedrake/slate/slac/utils.py
+++ b/firedrake/slate/slac/utils.py
@@ -134,9 +134,14 @@ def _slate2gem_mul(expr, self):
     A, B = map(self, expr.children)
     *i, k = tuple(make_indices(len(A.shape)))
     _, *j = tuple(make_indices(len(B.shape)))
-    ABikj = Product(Indexed(A, tuple(i + [k])),
-                    Indexed(B, tuple([k] + j)))
-    return ComponentTensor(IndexSum(ABikj, (k, )), tuple(i + j))
+    if A.shape[-1] == 1:
+        ABij = Product(Indexed(A, tuple(i + [0])),
+                       Indexed(B, tuple([0] + j)))
+        return ComponentTensor(ABij, tuple(i + j))
+    else:
+        ABikj = Product(Indexed(A, tuple(i + [k])),
+                        Indexed(B, tuple([k] + j)))
+        return ComponentTensor(IndexSum(ABikj, (k, )), tuple(i + j))

But I don't think we should be worried about an extra rank of dimension 1 if we know how to deal with it properly. So I think that we should leave the PR as it is.

@pbrubeck pbrubeck requested a review from connorjward June 25, 2026 08:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants