diff --git a/api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java b/api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java index db3d517f0b..8538ebbae6 100644 --- a/api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java +++ b/api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java @@ -506,4 +506,45 @@ SELECT name, ROW_NUMBER() OVER (ORDER BY id) AS rn FROM catalog.employees LogicalTableScan(table=[[catalog, employees]]) """); } + + @Test + public void testGroupByExpression() { + givenQuery("SELECT LENGTH(name), COUNT(*) FROM catalog.employees GROUP BY LENGTH(name)") + .assertPlan( + """ + LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()]) + LogicalProject(LENGTH(name)=[CHAR_LENGTH($1)]) + LogicalTableScan(table=[[catalog, employees]]) + """); + } + + @Test + public void testHavingOnGroupByExpression() { + givenQuery( + "SELECT COUNT(*) FROM catalog.employees GROUP BY LENGTH(name) HAVING LENGTH(name) > 3") + .assertPlan( + """ + LogicalProject(COUNT(*)=[$0]) + LogicalFilter(condition=[>($1, 3)]) + LogicalProject(COUNT(*)=[$1], LENGTH(name)=[$0]) + LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()]) + LogicalProject(LENGTH(name)=[CHAR_LENGTH($1)]) + LogicalTableScan(table=[[catalog, employees]]) + """); + } + + @Test + public void testOrderByGroupByExpression() { + givenQuery( + """ + SELECT LENGTH(name) FROM catalog.employees GROUP BY LENGTH(name) ORDER BY LENGTH(name) + """) + .assertPlan( + """ + LogicalSort(sort0=[$0], dir0=[ASC-nulls-first]) + LogicalAggregate(group=[{0}]) + LogicalProject(LENGTH(name)=[CHAR_LENGTH($1)]) + LogicalTableScan(table=[[catalog, employees]]) + """); + } } diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java b/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java index f18ac0b047..7ca7ab0930 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java @@ -22,6 +22,7 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.tools.FrameworkConfig; import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.HighlightConfig; import org.opensearch.sql.calcite.utils.CalciteToolsHelper; @@ -71,6 +72,9 @@ public class CalcitePlanContext { */ @Getter private final Map aggregateOutputIndex = new HashMap<>(); + /** Maps GROUP BY Function AST nodes to their output field index for post-aggregate resolution. */ + @Getter private final Map groupKeyOutputIndex = new HashMap<>(); + /** * List of captured variables from outer scope for lambda functions. When a lambda body references * a field that is not a lambda parameter, it gets captured and stored here. The captured diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index 2ed139d781..2658066bfb 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -1788,6 +1788,14 @@ private void visitAggregation( context.getAggregateOutputIndex().put(aggFunc, aggStartIdx + i); } } + context.getGroupKeyOutputIndex().clear(); + int groupStartIdx = metricsFirst ? aggRexList.size() : 0; + for (int i = 0; i < groupExprList.size(); i++) { + Function groupFunc = extractFunction(groupExprList.get(i)); + if (groupFunc != null) { + context.getGroupKeyOutputIndex().put(groupFunc, groupStartIdx + i); + } + } } private static AggregateFunction extractAggregateFunction(UnresolvedExpression expr) { @@ -1796,6 +1804,12 @@ private static AggregateFunction extractAggregateFunction(UnresolvedExpression e return null; } + private static Function extractFunction(UnresolvedExpression expr) { + if (expr instanceof Function f) return f; + if (expr instanceof Alias alias) return extractFunction(alias.getDelegated()); + return null; + } + /** * Collects input refs used by aggregate FILTER(WHERE ...) predicates so trimming retains them. */ diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java index 1d473b168e..0c1326f868 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java @@ -563,6 +563,12 @@ private List modifyLambdaTypeByFunction( @Override public RexNode visitFunction(Function node, CalcitePlanContext context) { + // Post-aggregate, a GROUP BY function expression is a materialized output column; reference it + // instead of recomputing from base fields the aggregation removed. + Integer groupKeyIndex = context.getGroupKeyOutputIndex().get(node); + if (groupKeyIndex != null) { + return context.relBuilder.field(groupKeyIndex); + } List args = node.getFuncArgs(); List arguments = new ArrayList<>();