RuleErrorCollectorJUnitPlugin.java
/*******************************************************************************
* Copyright (c) 2026 Carsten Hammer.
*
* This program and the accompanying materials
* are made available under the terms of the Eclipse Public License 2.0
* which accompanies this distribution, and is available at
* https://www.eclipse.org/legal/epl-2.0/
*
* SPDX-License-Identifier: EPL-2.0
*
* Contributors:
* Carsten Hammer
*******************************************************************************/
package org.sandbox.jdt.internal.corext.fix.helper;
import static org.sandbox.jdt.internal.corext.fix.helper.lib.JUnitConstants.*;
import java.util.ArrayList;
import java.util.List;
import org.eclipse.jdt.core.dom.AST;
import org.eclipse.jdt.core.dom.ASTNode;
import org.eclipse.jdt.core.dom.Block;
import org.eclipse.jdt.core.dom.Expression;
import org.eclipse.jdt.core.dom.ExpressionStatement;
import org.eclipse.jdt.core.dom.FieldDeclaration;
import org.eclipse.jdt.core.dom.ITypeBinding;
import org.eclipse.jdt.core.dom.LambdaExpression;
import org.eclipse.jdt.core.dom.MethodDeclaration;
import org.eclipse.jdt.core.dom.MethodInvocation;
import org.eclipse.jdt.core.dom.SimpleName;
import org.eclipse.jdt.core.dom.Statement;
import org.eclipse.jdt.core.dom.ThrowStatement;
import org.eclipse.jdt.core.dom.TypeDeclaration;
import org.eclipse.jdt.core.dom.VariableDeclarationFragment;
import org.eclipse.jdt.core.dom.rewrite.ASTRewrite;
import org.eclipse.jdt.core.dom.rewrite.ImportRewrite;
import org.eclipse.jdt.internal.corext.dom.ASTNodes;
import org.eclipse.text.edits.TextEditGroup;
import org.sandbox.jdt.internal.common.AstProcessorBuilder;
import org.sandbox.jdt.internal.common.ReferenceHolder;
import org.sandbox.jdt.internal.corext.fix.helper.lib.JunitHolder;
import org.sandbox.jdt.internal.corext.fix.helper.lib.TriggerPatternCleanupPlugin;
import org.sandbox.jdt.triggerpattern.api.CleanupPattern;
import org.sandbox.jdt.triggerpattern.api.Match;
import org.sandbox.jdt.triggerpattern.api.PatternKind;
/**
* Plugin to migrate JUnit 4 ErrorCollector rule to JUnit 5 assertAll.
*
* Transforms: - collector.checkThat(actual, matcher) → () -> assertThat(actual,
* matcher) - collector.addError(throwable) → () -> { throw throwable; } -
* collector.checkSucceeds(callable) → () -> callable.call()
*
* All transformations are wrapped in assertAll() per test method.
*
* @since 1.3.0
*/
@CleanupPattern(value = "@Rule public ErrorCollector $name", kind = PatternKind.FIELD, qualifiedType = ORG_JUNIT_RULES_ERROR_COLLECTOR, cleanupId = "cleanup.junit.ruleerrorcollector", description = "Migrate @Rule ErrorCollector to assertAll()", displayName = "JUnit 4 @Rule ErrorCollector \u2192 JUnit 5 assertAll()")
public class RuleErrorCollectorJUnitPlugin extends TriggerPatternCleanupPlugin {
@Override
protected JunitHolder createHolder(Match match) {
FieldDeclaration fieldDecl = (FieldDeclaration) match.getMatchedNode();
VariableDeclarationFragment fragment = (VariableDeclarationFragment) fieldDecl.fragments().get(0);
ITypeBinding binding = fragment.resolveBinding() != null ? fragment.resolveBinding().getType() : null;
boolean isErrorCollector;
if (binding != null && ORG_JUNIT_RULES_ERROR_COLLECTOR.equals(binding.getQualifiedName())) {
isErrorCollector = true;
} else {
// Fallback: check by source type name when binding is unavailable or recovered
String typeName = fieldDecl.getType().toString();
String simpleTypeName = ORG_JUNIT_RULES_ERROR_COLLECTOR
.substring(ORG_JUNIT_RULES_ERROR_COLLECTOR.lastIndexOf('.') + 1);
isErrorCollector = simpleTypeName.equals(typeName) || ORG_JUNIT_RULES_ERROR_COLLECTOR.equals(typeName);
}
if (!isErrorCollector) {
return null;
}
JunitHolder holder = new JunitHolder();
holder.setMinv(fieldDecl);
return holder;
}
@Override
protected void process2Rewrite(TextEditGroup group, ASTRewrite rewriter, AST ast, ImportRewrite importRewriter,
JunitHolder junitHolder) {
FieldDeclaration field = junitHolder.getFieldDeclaration();
TypeDeclaration parentClass = ASTNodes.getParent(field, TypeDeclaration.class);
VariableDeclarationFragment originalFragment = (VariableDeclarationFragment) field.fragments().get(0);
String fieldName = originalFragment.getName().getIdentifier();
// Remove the field declaration
rewriter.remove(field, group);
// Remove old imports
importRewriter.removeImport(ORG_JUNIT_RULE);
importRewriter.removeImport(ORG_JUNIT_RULES_ERROR_COLLECTOR);
// Add new imports
importRewriter.addStaticImport(ORG_JUNIT_JUPITER_API_ASSERTIONS, "assertAll", false);
// Transform all test methods that use the ErrorCollector field
for (MethodDeclaration method : parentClass.getMethods()) {
transformTestMethod(method, fieldName, rewriter, ast, group, importRewriter);
}
}
private void transformTestMethod(MethodDeclaration method, String fieldName, ASTRewrite rewriter, AST ast,
TextEditGroup group, ImportRewrite importRewriter) {
Block methodBody = method.getBody();
if (methodBody == null) {
return;
}
List<Statement> statements = methodBody.statements();
if (statements.isEmpty()) {
return;
}
// Find all ErrorCollector method invocations in this method
List<ErrorCollectorCall> errorCollectorCalls = findErrorCollectorCalls(statements, fieldName);
if (errorCollectorCalls.isEmpty()) {
// This method doesn't use the ErrorCollector field
return;
}
// Create assertAll() call with lambda expressions for each error collector call
MethodInvocation assertAllCall = ast.newMethodInvocation();
assertAllCall.setName(ast.newSimpleName("assertAll"));
// Create lambda expressions for each ErrorCollector call
for (ErrorCollectorCall call : errorCollectorCalls) {
LambdaExpression lambda = createLambdaForErrorCollectorCall(call, ast, importRewriter);
assertAllCall.arguments().add(lambda);
}
// Create the new assertAll statement
ExpressionStatement assertAllStatement = ast.newExpressionStatement(assertAllCall);
// Remove all old ErrorCollector calls
for (int i = errorCollectorCalls.size() - 1; i >= 0; i--) {
ErrorCollectorCall call = errorCollectorCalls.get(i);
rewriter.remove(call.statement, group);
}
// Insert the assertAll statement where the first ErrorCollector call was
if (!errorCollectorCalls.isEmpty()) {
ErrorCollectorCall firstCall = errorCollectorCalls.get(0);
int insertIndex = statements.indexOf(firstCall.statement);
rewriter.getListRewrite(methodBody, Block.STATEMENTS_PROPERTY).insertAt(assertAllStatement, insertIndex,
group);
}
}
private List<ErrorCollectorCall> findErrorCollectorCalls(List<Statement> statements, String fieldName) {
List<ErrorCollectorCall> calls = new ArrayList<>();
// Use AstProcessorBuilder to find all ErrorCollector calls, including nested ones
for (Statement stmt : statements) {
ReferenceHolder<String, Object> holder = ReferenceHolder.create();
AstProcessorBuilder.with(holder)
.onMethodInvocation((invocation, h) -> {
Expression expression = invocation.getExpression();
if (expression instanceof SimpleName) {
SimpleName receiver = (SimpleName) expression;
if (fieldName.equals(receiver.getIdentifier())) {
String methodName = invocation.getName().getIdentifier();
if ("checkThat".equals(methodName) || "addError".equals(methodName)
|| "checkSucceeds".equals(methodName)) {
// Find the parent statement that contains this invocation
Statement parentStmt = findParentStatement(invocation);
if (parentStmt != null) {
calls.add(new ErrorCollectorCall(parentStmt, invocation, methodName));
}
}
}
}
return true;
})
.build(stmt);
}
return calls;
}
private Statement findParentStatement(ASTNode node) {
ASTNode current = node;
while (current != null && !(current instanceof Statement)) {
current = current.getParent();
}
return (Statement) current;
}
private LambdaExpression createLambdaForErrorCollectorCall(ErrorCollectorCall call, AST ast,
ImportRewrite importRewriter) {
LambdaExpression lambda = ast.newLambdaExpression();
lambda.setParentheses(true);
MethodInvocation invocation = call.invocation;
String methodName = call.methodName;
if ("checkThat".equals(methodName)) {
// checkThat(actual, matcher) → () -> assertThat(actual, matcher)
// Use expression-body lambda for single-expression case
// Create assertThat call with the same arguments
MethodInvocation assertThatCall = ast.newMethodInvocation();
assertThatCall.setName(ast.newSimpleName("assertThat"));
// Copy arguments
for (Object arg : invocation.arguments()) {
assertThatCall.arguments().add(ASTNode.copySubtree(ast, (ASTNode) arg));
}
// Set expression body directly (no block)
lambda.setBody(assertThatCall);
// Add Hamcrest imports for assertThat
importRewriter.addStaticImport("org.hamcrest.MatcherAssert", "assertThat", false);
} else if ("addError".equals(methodName)) {
// addError(throwable) → () -> { throw throwable; }
// This requires a block body since throw is a statement, not an expression
Block lambdaBody = ast.newBlock();
ThrowStatement throwStmt = ast.newThrowStatement();
// The argument is the throwable to throw
Expression throwableArg = (Expression) invocation.arguments().get(0);
throwStmt.setExpression((Expression) ASTNode.copySubtree(ast, throwableArg));
lambdaBody.statements().add(throwStmt);
lambda.setBody(lambdaBody);
} else if ("checkSucceeds".equals(methodName)) {
// checkSucceeds(callable) → () -> callable.call()
// Use expression-body lambda for single-expression case
// Create callable.call() invocation
Expression callableArg = (Expression) invocation.arguments().get(0);
MethodInvocation callInvocation = ast.newMethodInvocation();
callInvocation.setExpression((Expression) ASTNode.copySubtree(ast, callableArg));
callInvocation.setName(ast.newSimpleName("call"));
// Set expression body directly (no block)
lambda.setBody(callInvocation);
}
return lambda;
}
@Override
public String getPreview(boolean afterRefactoring) {
if (afterRefactoring) {
return """
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.CoreMatchers.equalTo;
import org.junit.jupiter.api.Test;
public class MyTest {
@Test
public void testMultipleErrors() {
assertAll(
() -> assertThat("value1", equalTo("expected1")),
() -> assertThat("value2", equalTo("expected2")),
() -> { throw new Throwable("error message"); }
);
}
}
"""; //$NON-NLS-1$
}
return """
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ErrorCollector;
import static org.hamcrest.CoreMatchers.equalTo;
public class MyTest {
@Rule
public ErrorCollector collector = new ErrorCollector();
@Test
public void testMultipleErrors() {
collector.checkThat("value1", equalTo("expected1"));
collector.checkThat("value2", equalTo("expected2"));
collector.addError(new Throwable("error message"));
}
}
"""; //$NON-NLS-1$
}
@Override
public String toString() {
return "RuleErrorCollector"; //$NON-NLS-1$
}
/**
* Helper class to hold ErrorCollector call information
*/
private static class ErrorCollectorCall {
final Statement statement;
final MethodInvocation invocation;
final String methodName;
ErrorCollectorCall(Statement statement, MethodInvocation invocation, String methodName) {
this.statement = statement;
this.invocation = invocation;
this.methodName = methodName;
}
}
}