Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;

/**
Expand Down Expand Up @@ -309,8 +308,12 @@ private Optional<CelMutableAst> maybeFold(
return maybeRewriteOptional(optResult, mutableAst, node.expr());
}

return maybeAdaptEvaluatedResult(result)
.map(celExpr -> astMutator.replaceSubtree(mutableAst, celExpr, node.id()));
CelMutableExpr adaptedResult = maybeAdaptEvaluatedResult(result).orElse(null);
if (adaptedResult == null) {
return Optional.empty();
}

return Optional.of(astMutator.replaceSubtree(mutableAst, adaptedResult, node.id()));
}

private Optional<CelMutableExpr> maybeAdaptEvaluatedResult(Object result) {
Expand All @@ -331,7 +334,7 @@ private Optional<CelMutableExpr> maybeAdaptEvaluatedResult(Object result) {
} else if (result instanceof Map<?, ?>) {
Map<?, ?> map = (Map<?, ?>) result;
List<CelMutableMap.Entry> mapEntries = new ArrayList<>();
for (Entry<?, ?> entry : map.entrySet()) {
for (Map.Entry<?, ?> entry : map.entrySet()) {
CelMutableExpr adaptedKey = maybeAdaptEvaluatedResult(entry.getKey()).orElse(null);
if (adaptedKey == null) {
return Optional.empty();
Expand Down Expand Up @@ -384,16 +387,15 @@ private Optional<CelMutableAst> maybeRewriteOptional(
return Optional.empty();
}

if (!CelConstant.isConstantValue(unwrappedResult)) {
// Evaluated result is not a constant. Leave the optional as is.
CelMutableExpr adaptedResult = maybeAdaptEvaluatedResult(unwrappedResult).orElse(null);
if (adaptedResult == null) {
// Evaluated result is not an adaptable constant. Leave the optional as is.
return Optional.empty();
}

CelMutableExpr newOptionalOfCall =
CelMutableExpr.ofCall(
CelMutableCall.create(
Function.OPTIONAL_OF.getFunction(),
CelMutableExpr.ofConstant(CelConstant.ofObjectValue(unwrappedResult))));
CelMutableCall.create(Function.OPTIONAL_OF.getFunction(), adaptedResult));

return Optional.of(astMutator.replaceSubtree(mutableAst, newOptionalOfCall, expr.id()));
}
Expand Down Expand Up @@ -530,6 +532,37 @@ private Optional<CelMutableAst> maybeShortCircuitCall(
"Folding variadic logical operator is not supported yet.");
}

private boolean isFoldedAggregateLiteral(CelMutableExpr expr) {
if (expr.getKind().equals(Kind.CONSTANT)) {
return true;
}
if (expr.getKind().equals(Kind.LIST)) {
for (CelMutableExpr child : expr.list().elements()) {
if (!isFoldedAggregateLiteral(child)) {
return false;
}
}
return true;
}
if (expr.getKind().equals(Kind.MAP)) {
for (CelMutableExpr.CelMutableMap.Entry entry : expr.map().entries()) {
if (!isFoldedAggregateLiteral(entry.key()) || !isFoldedAggregateLiteral(entry.value())) {
return false;
}
}
return true;
}
if (expr.getKind().equals(Kind.STRUCT)) {
for (CelMutableExpr.CelMutableStruct.Entry entry : expr.struct().entries()) {
if (!isFoldedAggregateLiteral(entry.value())) {
return false;
}
}
return true;
}
return false;
}

private CelMutableAst pruneOptionalElements(CelMutableAst ast) {
ImmutableList<CelMutableExpr> aggregateLiterals =
CelNavigableMutableExpr.fromExpr(ast.expr())
Expand Down Expand Up @@ -588,7 +621,7 @@ private CelMutableAst pruneOptionalListElements(CelMutableAst mutableAst, CelMut
continue;
} else if (call.function().equals(Function.OPTIONAL_OF.getFunction())) {
CelMutableExpr arg = call.args().get(0);
if (arg.getKind().equals(Kind.CONSTANT)) {
if (isFoldedAggregateLiteral(arg)) {
updatedElemBuilder.add(call.args().get(0));
continue;
}
Expand Down Expand Up @@ -629,7 +662,7 @@ private CelMutableAst pruneOptionalMapElements(CelMutableAst ast, CelMutableExpr
continue;
} else if (call.function().equals(Function.OPTIONAL_OF.getFunction())) {
CelMutableExpr arg = call.args().get(0);
if (arg.getKind().equals(Kind.CONSTANT)) {
if (isFoldedAggregateLiteral(arg)) {
modified = true;
entry.setOptionalEntry(false);
entry.setValue(call.args().get(0));
Expand Down Expand Up @@ -670,7 +703,7 @@ private CelMutableAst pruneOptionalStructElements(CelMutableAst ast, CelMutableE
continue;
} else if (call.function().equals(Function.OPTIONAL_OF.getFunction())) {
CelMutableExpr arg = call.args().get(0);
if (arg.getKind().equals(Kind.CONSTANT)) {
if (isFoldedAggregateLiteral(arg)) {
modified = true;
entry.setOptionalEntry(false);
entry.setValue(call.args().get(0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,24 @@ private static Cel setupEnv(CelBuilder celBuilder) {
@TestParameters(
"{source: 'has({\"req\": \"Avail\"}.opt) ? ({\"req\": \"Avail\"}.req + \" \" +"
+ " {\"req\": \"Avail\"}.opt) : {\"req\": \"Avail\"}.req', expected: '\"Avail\"'}")
@TestParameters("{source: 'true || optional.none().hasValue()', expected: 'true'}")
@TestParameters("{source: 'false && map_var[?\"missing\"].hasValue()', expected: 'false'}")
@TestParameters("{source: '{\"hello\": [1, 2]}.?hello', expected: 'optional.of([1, 2])'}")
@TestParameters(
"{source: '{?\"key\": optional.of({\"a\": 1})}', expected: '{\"key\": {\"a\": 1}}'}")
@TestParameters(
"{source: 'TestAllTypes{?repeated_int32: optional.of([1, 2])}',"
+ " expected: 'cel.expr.conformance.proto3.TestAllTypes{repeated_int32: [1, 2]}'}")
@TestParameters("{source: '[?optional.of([1, x])]', expected: '[?optional.of([1, x])]'}")
@TestParameters("{source: '[?optional.of({\"a\": x})]', expected: '[?optional.of({\"a\": x})]'}")
@TestParameters("{source: '[?optional.of({x: 1})]', expected: '[?optional.of({x: 1})]'}")
@TestParameters(
"{source: '[?optional.of(TestAllTypes{single_int32: x})]', expected:"
+ " '[?optional.of(cel.expr.conformance.proto3.TestAllTypes{single_int32: x})]'}")
@TestParameters(
"{source: '[?optional.of(TestAllTypes{single_int32: 1})]', expected:"
+ " '[cel.expr.conformance.proto3.TestAllTypes{single_int32: 1}]'}")
@TestParameters("{source: '[?optional.of(x)]', expected: '[?optional.of(x)]'}")
// TODO: Support folding lists with mixed types. This requires mutable lists.
// @TestParameters("{source: 'dyn([1]) + [1.0]'}")
public void constantFold_success(String source, String expected) throws Exception {
Expand Down Expand Up @@ -560,6 +578,4 @@ public void iterationLimitReached_throws() throws Exception {
assertThrows(CelOptimizationException.class, () -> optimizer.optimize(ast));
assertThat(e).hasMessageThat().contains("Optimization failure: Max iteration count reached.");
}


}
Loading