Skip to content

Commit ad87510

Browse files
Support generating queries containing expression parameter binding during AOT run.
1 parent eeb06cd commit ad87510

File tree

10 files changed

+247
-120
lines changed

10 files changed

+247
-120
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
package org.springframework.data.mongodb.repository.aot;
1717

1818
import java.util.ArrayList;
19+
import java.util.LinkedHashMap;
1920
import java.util.List;
21+
import java.util.Map;
2022
import java.util.stream.Collectors;
2123
import java.util.stream.Stream;
2224

@@ -150,7 +152,7 @@ static class AggregationCodeBlockBuilder {
150152

151153
private final AotQueryMethodGenerationContext context;
152154
private final MongoQueryMethod queryMethod;
153-
private final List<CodeBlock> arguments;
155+
private final Map<String, CodeBlock> arguments;
154156

155157
private AggregationInteraction source;
156158

@@ -160,7 +162,8 @@ static class AggregationCodeBlockBuilder {
160162
AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
161163

162164
this.context = context;
163-
this.arguments = context.getBindableParameterNames().stream().map(CodeBlock::of).collect(Collectors.toList());
165+
this.arguments = new LinkedHashMap<>();
166+
context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it)));
164167
this.queryMethod = queryMethod;
165168
}
166169

@@ -288,7 +291,7 @@ private CodeBlock aggregationOptions(String aggregationVariableName) {
288291
}
289292

290293
private CodeBlock aggregationStages(String stageListVariableName, Iterable<String> stages, int stageCount,
291-
List<CodeBlock> arguments) {
294+
Map<String, CodeBlock> arguments) {
292295

293296
Builder builder = CodeBlock.builder();
294297
builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class,

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,25 @@
2121

2222
import org.bson.Document;
2323
import org.jspecify.annotations.Nullable;
24+
import org.springframework.data.expression.ValueEvaluationContext;
25+
import org.springframework.data.expression.ValueExpression;
26+
import org.springframework.data.mapping.model.ValueExpressionEvaluator;
2427
import org.springframework.data.mongodb.BindableMongoExpression;
2528
import org.springframework.data.mongodb.core.MongoOperations;
2629
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
2730
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
2831
import org.springframework.data.mongodb.core.convert.MongoConverter;
2932
import org.springframework.data.mongodb.core.mapping.FieldName;
3033
import org.springframework.data.mongodb.core.query.BasicQuery;
34+
import org.springframework.data.mongodb.repository.query.MongoParameters;
35+
import org.springframework.data.mongodb.util.json.ParameterBindingContext;
36+
import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec;
37+
import org.springframework.data.mongodb.util.json.ValueProvider;
3138
import org.springframework.data.projection.ProjectionFactory;
3239
import org.springframework.data.repository.core.RepositoryMetadata;
3340
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
41+
import org.springframework.data.repository.query.ValueExpressionDelegate;
42+
import org.springframework.expression.EvaluationContext;
3443
import org.springframework.util.ClassUtils;
3544
import org.springframework.util.ObjectUtils;
3645

@@ -46,31 +55,69 @@ public class MongoAotRepositoryFragmentSupport {
4655
private final MongoOperations mongoOperations;
4756
private final MongoConverter mongoConverter;
4857
private final ProjectionFactory projectionFactory;
58+
private final ValueExpressionDelegate valueExpressionDelegate;
4959

5060
protected MongoAotRepositoryFragmentSupport(MongoOperations mongoOperations,
5161
RepositoryFactoryBeanSupport.FragmentCreationContext context) {
52-
this(mongoOperations, context.getRepositoryMetadata(), context.getProjectionFactory());
62+
this(mongoOperations, context.getRepositoryMetadata(), context.getProjectionFactory(),
63+
context.getValueExpressionDelegate());
5364
}
5465

5566
protected MongoAotRepositoryFragmentSupport(MongoOperations mongoOperations, RepositoryMetadata repositoryMetadata,
56-
ProjectionFactory projectionFactory) {
67+
ProjectionFactory projectionFactory, ValueExpressionDelegate valueExpressionDelegate) {
5768

5869
this.mongoOperations = mongoOperations;
5970
this.mongoConverter = mongoOperations.getConverter();
6071
this.repositoryMetadata = repositoryMetadata;
6172
this.projectionFactory = projectionFactory;
73+
this.valueExpressionDelegate = valueExpressionDelegate;
6274
}
6375

6476
protected Document bindParameters(String source, Object[] parameters) {
6577
return new BindableMongoExpression(source, this.mongoConverter, parameters).toDocument();
6678
}
6779

80+
protected Document bindParameters(String source, Map<String, Object> parameters) {
81+
82+
ValueEvaluationContext valueEvaluationContext = this.valueExpressionDelegate.getEvaluationContextAccessor()
83+
.create(new NoMongoParameters()).getEvaluationContext(parameters.values());
84+
85+
EvaluationContext evaluationContext = valueEvaluationContext.getEvaluationContext();
86+
parameters.forEach(evaluationContext::setVariable);
87+
88+
ParameterBindingContext bindingContext = new ParameterBindingContext(new ValueProvider() {
89+
90+
private final List<Object> args = new ArrayList<>(parameters.values());
91+
92+
@Override
93+
public @Nullable Object getBindableValue(int index) {
94+
return args.get(index);
95+
}
96+
}, new ValueExpressionEvaluator() {
97+
98+
@Override
99+
@SuppressWarnings("unchecked")
100+
public <T> @Nullable T evaluate(String expression) {
101+
ValueExpression parse = valueExpressionDelegate.getValueExpressionParser().parse(expression);
102+
return (T) parse.evaluate(valueEvaluationContext);
103+
}
104+
});
105+
106+
return new ParameterBindingDocumentCodec().decode(source, bindingContext);
107+
}
108+
68109
protected BasicQuery createQuery(String queryString, Object[] parameters) {
69110

70111
Document queryDocument = bindParameters(queryString, parameters);
71112
return new BasicQuery(queryDocument);
72113
}
73114

115+
protected BasicQuery createQuery(String queryString, Map<String, Object> parameters) {
116+
117+
Document queryDocument = bindParameters(queryString, parameters);
118+
return new BasicQuery(queryDocument);
119+
}
120+
74121
protected AggregationPipeline createPipeline(List<Object> rawStages) {
75122

76123
List<AggregationOperation> stages = new ArrayList<>(rawStages.size());
@@ -151,4 +198,10 @@ private static <T> T getPotentiallyConvertedSimpleTypeValue(MongoConverter conve
151198
return converter.getConversionService().convert(value, targetType);
152199
}
153200

201+
static class NoMongoParameters extends MongoParameters {
202+
203+
NoMongoParameters() {
204+
super();
205+
}
206+
}
154207
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
package org.springframework.data.mongodb.repository.aot;
1717

1818
import java.util.Iterator;
19-
import java.util.List;
19+
import java.util.Map;
20+
import java.util.Map.Entry;
2021
import java.util.regex.Pattern;
2122

2223
import org.bson.Document;
@@ -47,6 +48,7 @@
4748
class MongoCodeBlocks {
4849

4950
private static final Pattern PARAMETER_BINDING_PATTERN = Pattern.compile("\\?(\\d+)");
51+
private static final Pattern EXPRESSION_BINDING_PATTERN = Pattern.compile("[\\?:][#$]\\{.*\\}");
5052

5153
/**
5254
* Builder for generating query parsing {@link CodeBlock}.
@@ -166,29 +168,69 @@ static GeoNearExecutionCodeBlockBuilder geoNearExecutionBlockBuilder(AotQueryMet
166168
return new GeoNearExecutionCodeBlockBuilder(context, queryMethod);
167169
}
168170

169-
static CodeBlock renderExpressionToDocument(@Nullable String source, String variableName, List<CodeBlock> arguments) {
171+
static CodeBlock renderExpressionToDocument(@Nullable String source, String variableName,
172+
Map<String, CodeBlock> arguments) {
170173

171174
Builder builder = CodeBlock.builder();
172175
if (!StringUtils.hasText(source)) {
173176
builder.addStatement("$1T $2L = new $1T()", Document.class, variableName);
174177
} else if (!containsPlaceholder(source)) {
175178
builder.addStatement("$1T $2L = $1T.parse($3S)", Document.class, variableName, source);
176179
} else {
180+
builder.add("$T $L = bindParameters($S, ", Document.class, variableName, source);
181+
if (containsNamedPlaceholder(source)) {
182+
renderArgumentMap(arguments);
183+
} else {
184+
builder.add(renderArgumentArray(arguments));
185+
}
186+
builder.add(");\n");
187+
}
188+
return builder.build();
189+
}
190+
191+
static CodeBlock renderArgumentMap(Map<String, CodeBlock> arguments) {
177192

178-
builder.add("$T $L = bindParameters($S, new $T[]{ ", Document.class, variableName, source, Object.class);
179-
Iterator<CodeBlock> iterator = arguments.iterator();
180-
while (iterator.hasNext()) {
181-
builder.add(iterator.next());
182-
if (iterator.hasNext()) {
183-
builder.add(", ");
184-
}
193+
Builder builder = CodeBlock.builder();
194+
builder.add("$T.of(", Map.class);
195+
Iterator<Entry<String, CodeBlock>> iterator = arguments.entrySet().iterator();
196+
while (iterator.hasNext()) {
197+
Entry<String, CodeBlock> next = iterator.next();
198+
builder.add("$S, ", next.getKey());
199+
builder.add(next.getValue());
200+
if (iterator.hasNext()) {
201+
builder.add(", ");
185202
}
186-
builder.add("});\n");
187203
}
204+
builder.add(")");
205+
return builder.build();
206+
}
207+
208+
static CodeBlock renderArgumentArray(Map<String, CodeBlock> arguments) {
209+
210+
Builder builder = CodeBlock.builder();
211+
builder.add("new $T[]{ ", Object.class);
212+
Iterator<CodeBlock> iterator = arguments.values().iterator();
213+
while (iterator.hasNext()) {
214+
builder.add(iterator.next());
215+
if (iterator.hasNext()) {
216+
builder.add(", ");
217+
} else {
218+
builder.add(" ");
219+
}
220+
}
221+
builder.add("}");
188222
return builder.build();
189223
}
190224

191225
static boolean containsPlaceholder(String source) {
226+
return containsIndexedPlaceholder(source) || containsNamedPlaceholder(source);
227+
}
228+
229+
static boolean containsNamedPlaceholder(String source) {
230+
return EXPRESSION_BINDING_PATTERN.matcher(source).find();
231+
}
232+
233+
static boolean containsIndexedPlaceholder(String source) {
192234
return PARAMETER_BINDING_PATTERN.matcher(source).find();
193235
}
194236

@@ -198,8 +240,8 @@ static void appendReadPreference(AotQueryMethodGenerationContext context, Builde
198240
String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null;
199241

200242
if (StringUtils.hasText(readPreference)) {
201-
builder.addStatement("$L.withReadPreference($T.valueOf($S))", queryVariableName,
202-
com.mongodb.ReadPreference.class, readPreference);
243+
builder.addStatement("$L.withReadPreference($T.valueOf($S))", queryVariableName, com.mongodb.ReadPreference.class,
244+
readPreference);
203245
}
204246
}
205247
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
import java.lang.reflect.Method;
3030
import java.util.Locale;
31-
import java.util.regex.Pattern;
3231

3332
import org.apache.commons.logging.Log;
3433
import org.apache.commons.logging.LogFactory;
@@ -117,18 +116,6 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB
117116
return nearQueryMethodContributor(queryMethod, near);
118117
}
119118

120-
if (queryMethod.hasAnnotatedQuery()) {
121-
if (StringUtils.hasText(queryMethod.getAnnotatedQuery())
122-
&& Pattern.compile("[\\?:][#$]\\{.*\\}").matcher(queryMethod.getAnnotatedQuery()).find()) {
123-
124-
if (logger.isDebugEnabled()) {
125-
logger.debug(
126-
"Skipping AOT generation for [%s]. SpEL expressions are not supported".formatted(method.getName()));
127-
}
128-
return MethodContributor.forQueryMethod(queryMethod).metadataOnly(query);
129-
}
130-
}
131-
132119
if (query.isDelete()) {
133120
return deleteMethodContributor(queryMethod, query);
134121
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
*/
1616
package org.springframework.data.mongodb.repository.aot;
1717

18-
import java.util.ArrayList;
19-
import java.util.Iterator;
18+
import java.util.LinkedHashMap;
2019
import java.util.List;
20+
import java.util.Map;
2121
import java.util.Optional;
2222

2323
import org.bson.Document;
@@ -147,14 +147,14 @@ static class QueryCodeBlockBuilder {
147147
private final MongoQueryMethod queryMethod;
148148

149149
private QueryInteraction source;
150-
private final List<CodeBlock> arguments;
150+
private final Map<String, CodeBlock> arguments;
151151
private String queryVariableName;
152152

153153
QueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
154154

155155
this.context = context;
156156

157-
this.arguments = new ArrayList<>();
157+
this.arguments = new LinkedHashMap<>();
158158
this.queryMethod = queryMethod;
159159
collectArguments(context);
160160

@@ -167,29 +167,29 @@ private void collectArguments(AotQueryMethodGenerationContext context) {
167167
if (ClassUtils.isAssignable(GeoJson.class, parameter.getType())) {
168168

169169
// renders as generic $geometry, thus can be handled by the converter when parsing
170-
arguments.add(CodeBlock.of(parameterName));
170+
arguments.put(parameterName, CodeBlock.of(parameterName));
171171
} else if (ClassUtils.isAssignable(Circle.class, parameter.getType())
172172
|| ClassUtils.isAssignable(Sphere.class, parameter.getType())) {
173173

174174
// $center | $centerSphere : [ [ <x>, <y> ], <radius> ]
175-
arguments.add(CodeBlock.builder().add(
175+
arguments.put(parameterName, CodeBlock.builder().add(
176176
"$1T.of($1T.of($2L.getCenter().getX(), $2L.getCenter().getY()), $2L.getRadius().getNormalizedValue())",
177177
List.class, parameterName).build());
178178
} else if (ClassUtils.isAssignable(Box.class, parameter.getType())) {
179179

180180
// $box: [ [ <x1>, <y1> ], [ <x2>, <y2> ] ]
181-
arguments.add(CodeBlock.builder().add(
181+
arguments.put(parameterName, CodeBlock.builder().add(
182182
"$1T.of($1T.of($2L.getFirst().getX(), $2L.getFirst().getY()), $1T.of($2L.getSecond().getX(), $2L.getSecond().getY()))",
183183
List.class, parameterName).build());
184184
} else if (ClassUtils.isAssignable(Polygon.class, parameter.getType())) {
185185

186186
// $polygon: [ [ <x1> , <y1> ], [ <x2> , <y2> ], [ <x3> , <y3> ], ... ]
187187
String localVar = context.localVariable("_p");
188-
arguments.add(
188+
arguments.put(parameterName,
189189
CodeBlock.builder().add("$1L.getPoints().stream().map($2L -> $3T.of($2L.getX(), $2L.getY())).toList()",
190190
parameterName, localVar, List.class).build());
191191
} else {
192-
arguments.add(CodeBlock.of(parameterName));
192+
arguments.put(parameterName, CodeBlock.of(parameterName));
193193
}
194194
}
195195
}
@@ -284,17 +284,13 @@ private CodeBlock renderExpressionToQuery(@Nullable String source, String variab
284284
builder.addStatement("$1T $2L = new $1T($3T.parse($4S))", BasicQuery.class, variableName, Document.class,
285285
source);
286286
} else {
287-
builder.add("$T $L = createQuery($S, new $T[]{ ", BasicQuery.class, variableName, source, Object.class);
288-
Iterator<CodeBlock> iterator = arguments.iterator();
289-
while (iterator.hasNext()) {
290-
builder.add(iterator.next());
291-
if (iterator.hasNext()) {
292-
builder.add(", ");
293-
} else {
294-
builder.add(" ");
295-
}
287+
builder.add("$T $L = createQuery($S, ", BasicQuery.class, variableName, source);
288+
if (MongoCodeBlocks.containsNamedPlaceholder(source)) {
289+
builder.add(MongoCodeBlocks.renderArgumentMap(arguments));
290+
} else {
291+
builder.add(MongoCodeBlocks.renderArgumentArray(arguments));
296292
}
297-
builder.add("});\n");
293+
builder.add(");\n");
298294
}
299295

300296
return builder.build();

0 commit comments

Comments
 (0)