Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -506,4 +506,45 @@ SELECT name, ROW_NUMBER() OVER (ORDER BY id) AS rn FROM catalog.employees
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testGroupByExpression() {
givenQuery("SELECT LENGTH(name), COUNT(*) FROM catalog.employees GROUP BY LENGTH(name)")
.assertPlan(
"""
LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()])
LogicalProject(LENGTH(name)=[CHAR_LENGTH($1)])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testHavingOnGroupByExpression() {
givenQuery(
"SELECT COUNT(*) FROM catalog.employees GROUP BY LENGTH(name) HAVING LENGTH(name) > 3")
.assertPlan(
"""
LogicalProject(COUNT(*)=[$0])
LogicalFilter(condition=[>($1, 3)])
LogicalProject(COUNT(*)=[$1], LENGTH(name)=[$0])
LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()])
LogicalProject(LENGTH(name)=[CHAR_LENGTH($1)])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testOrderByGroupByExpression() {
givenQuery(
"""
SELECT LENGTH(name) FROM catalog.employees GROUP BY LENGTH(name) ORDER BY LENGTH(name)
""")
.assertPlan(
"""
LogicalSort(sort0=[$0], dir0=[ASC-nulls-first])
LogicalAggregate(group=[{0}])
LogicalProject(LENGTH(name)=[CHAR_LENGTH($1)])
LogicalTableScan(table=[[catalog, employees]])
""");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.FrameworkConfig;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.tree.HighlightConfig;
import org.opensearch.sql.calcite.utils.CalciteToolsHelper;
Expand Down Expand Up @@ -71,6 +72,9 @@ public class CalcitePlanContext {
*/
@Getter private final Map<AggregateFunction, Integer> aggregateOutputIndex = new HashMap<>();

/** Maps GROUP BY Function AST nodes to their output field index for post-aggregate resolution. */
@Getter private final Map<Function, Integer> groupKeyOutputIndex = new HashMap<>();

/**
* List of captured variables from outer scope for lambda functions. When a lambda body references
* a field that is not a lambda parameter, it gets captured and stored here. The captured
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1788,6 +1788,14 @@ private void visitAggregation(
context.getAggregateOutputIndex().put(aggFunc, aggStartIdx + i);
}
}
context.getGroupKeyOutputIndex().clear();
int groupStartIdx = metricsFirst ? aggRexList.size() : 0;
for (int i = 0; i < groupExprList.size(); i++) {
Function groupFunc = extractFunction(groupExprList.get(i));
if (groupFunc != null) {
context.getGroupKeyOutputIndex().put(groupFunc, groupStartIdx + i);
}
}
}

private static AggregateFunction extractAggregateFunction(UnresolvedExpression expr) {
Expand All @@ -1796,6 +1804,12 @@ private static AggregateFunction extractAggregateFunction(UnresolvedExpression e
return null;
}

private static Function extractFunction(UnresolvedExpression expr) {
if (expr instanceof Function f) return f;
if (expr instanceof Alias alias) return extractFunction(alias.getDelegated());
return null;
}

/**
* Collects input refs used by aggregate FILTER(WHERE ...) predicates so trimming retains them.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,12 @@ private List<RelDataType> modifyLambdaTypeByFunction(

@Override
public RexNode visitFunction(Function node, CalcitePlanContext context) {
// Post-aggregate, a GROUP BY function expression is a materialized output column; reference it
// instead of recomputing from base fields the aggregation removed.
Integer groupKeyIndex = context.getGroupKeyOutputIndex().get(node);
if (groupKeyIndex != null) {
return context.relBuilder.field(groupKeyIndex);
}
List<UnresolvedExpression> args = node.getFuncArgs();
List<RexNode> arguments = new ArrayList<>();

Expand Down
Loading