RuleErrorCollectorJUnitPlugin.java
/*******************************************************************************
* Copyright (c) 2026 Carsten Hammer.
*
* This program and the accompanying materials
* are made available under the terms of the Eclipse Public License 2.0
* which accompanies this distribution, and is available at
* https://www.eclipse.org/legal/epl-2.0/
*
* SPDX-License-Identifier: EPL-2.0
*
* Contributors:
* Carsten Hammer
*******************************************************************************/
package org.sandbox.jdt.internal.corext.fix.helper;
import static org.sandbox.jdt.internal.corext.fix.helper.lib.JUnitConstants.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import org.eclipse.jdt.core.dom.AST;
import org.eclipse.jdt.core.dom.ASTNode;
import org.eclipse.jdt.core.dom.ASTVisitor;
import org.eclipse.jdt.core.dom.Block;
import org.eclipse.jdt.core.dom.CompilationUnit;
import org.eclipse.jdt.core.dom.Expression;
import org.eclipse.jdt.core.dom.ExpressionStatement;
import org.eclipse.jdt.core.dom.FieldDeclaration;
import org.eclipse.jdt.core.dom.ITypeBinding;
import org.eclipse.jdt.core.dom.LambdaExpression;
import org.eclipse.jdt.core.dom.MethodDeclaration;
import org.eclipse.jdt.core.dom.MethodInvocation;
import org.eclipse.jdt.core.dom.SimpleName;
import org.eclipse.jdt.core.dom.Statement;
import org.eclipse.jdt.core.dom.ThrowStatement;
import org.eclipse.jdt.core.dom.TypeDeclaration;
import org.eclipse.jdt.core.dom.VariableDeclarationFragment;
import org.eclipse.jdt.core.dom.rewrite.ASTRewrite;
import org.eclipse.jdt.core.dom.rewrite.ImportRewrite;
import org.eclipse.jdt.internal.corext.dom.ASTNodes;
import org.eclipse.jdt.internal.corext.fix.CompilationUnitRewriteOperationsFixCore.CompilationUnitRewriteOperationWithSourceRange;
import org.eclipse.text.edits.TextEditGroup;
import org.sandbox.jdt.internal.common.HelperVisitor;
import org.sandbox.jdt.internal.common.ReferenceHolder;
import org.sandbox.jdt.internal.corext.fix.JUnitCleanUpFixCore;
import org.sandbox.jdt.internal.corext.fix.helper.lib.AbstractTool;
import org.sandbox.jdt.internal.corext.fix.helper.lib.JunitHolder;
/**
 * Plugin to migrate the JUnit 4 {@code ErrorCollector} rule to JUnit 5 {@code assertAll}.
 *
 * Transforms:
 * - collector.checkThat(actual, matcher) → () -> assertThat(actual, matcher)
 * - collector.checkThat(reason, actual, matcher) → () -> assertThat(reason, actual, matcher)
 * - collector.addError(throwable) → () -> { throw throwable; }
 * - collector.checkSucceeds(callable) → () -> callable.call()
 *   (lambdas and method references are first cast to {@code Callable<?>} so they get a target type)
 *
 * Every run of consecutive collector calls that sit directly in a method body becomes one assertAll()
 * statement at the position of that run. This keeps the checks in order relative to the surrounding
 * statements and keeps referenced local variables in scope. Unlike ErrorCollector, a failing assertAll()
 * ends the test, so failures from later runs in the same method are not reported.
 *
 * The field is migrated only when every reference to it has that shape. The plugin leaves the code
 * untouched for:
 * - calls nested in control flow or lambdas,
 * - other ErrorCollector methods,
 * - passing the field around,
 * - multi-variable declarations,
 * - fields whose declaring type is not a class/interface declaration.
 * Removing the field in those cases would leave dangling references or delete surrounding code.
 *
 * NOTE(review): locals captured by the generated lambdas must be effectively final; this is not checked.
 */
public class RuleErrorCollectorJUnitPlugin extends AbstractTool<ReferenceHolder<Integer, JunitHolder>> {

	private static final String CHECK_THAT = "checkThat"; //$NON-NLS-1$
	private static final String ADD_ERROR = "addError"; //$NON-NLS-1$
	private static final String CHECK_SUCCEEDS = "checkSucceeds"; //$NON-NLS-1$

	@Override
	public void find(JUnitCleanUpFixCore fixcore, CompilationUnit compilationUnit,
			Set<CompilationUnitRewriteOperationWithSourceRange> operations, Set<ASTNode> nodesprocessed) {
		ReferenceHolder<Integer, JunitHolder> dataHolder = new ReferenceHolder<>();
		HelperVisitor.forField()
				.withAnnotation(ORG_JUNIT_RULE)
				.ofType(ORG_JUNIT_RULES_ERROR_COLLECTOR)
				.in(compilationUnit)
				.excluding(nodesprocessed)
				.processEach(dataHolder, (visited, aholder) -> processFoundNode(fixcore, operations, (FieldDeclaration) visited, aholder));
	}

	/**
	 * Registers a rewrite operation for an {@code ErrorCollector} field that can be migrated completely.
	 *
	 * @return always {@code true} so that the visitor continues with the remaining fields
	 */
	private boolean processFoundNode(JUnitCleanUpFixCore fixcore,
			Set<CompilationUnitRewriteOperationWithSourceRange> operations, FieldDeclaration node,
			ReferenceHolder<Integer, JunitHolder> dataHolder) {
		VariableDeclarationFragment fragment = (VariableDeclarationFragment) node.fragments().get(0);
		if (fragment.resolveBinding() == null) {
			// Return true to continue processing other fields
			return true;
		}
		ITypeBinding binding = fragment.resolveBinding().getType();
		if (binding != null && ORG_JUNIT_RULES_ERROR_COLLECTOR.equals(binding.getQualifiedName())
				&& isMigratable(node)) {
			JunitHolder mh = new JunitHolder();
			mh.minv = node;
			dataHolder.put(dataHolder.size(), mh);
			operations.add(fixcore.rewrite(dataHolder));
		}
		// Return true to continue processing other fields
		return true;
	}

	@Override
	protected void process2Rewrite(TextEditGroup group, ASTRewrite rewriter, AST ast, ImportRewrite importRewriter,
			JunitHolder junitHolder) {
		FieldDeclaration field = junitHolder.getFieldDeclaration();
		if (!isMigratable(field)) {
			// A partial migration would remove the field while references to it remain
			return;
		}
		TypeDeclaration parentClass = (TypeDeclaration) field.getParent();
		VariableDeclarationFragment fragment = (VariableDeclarationFragment) field.fragments().get(0);
		// Remove the field declaration
		rewriter.remove(field, group);
		// Remove old imports; the assertAll/assertThat imports are added only where they are used
		importRewriter.removeImport(ORG_JUNIT_RULE);
		importRewriter.removeImport(ORG_JUNIT_RULES_ERROR_COLLECTOR);
		for (MethodDeclaration method : parentClass.getMethods()) {
			transformTestMethod(method, fragment, rewriter, ast, group, importRewriter);
		}
	}

	/**
	 * Checks whether the field can be removed safely: it must be the only variable of its declaration, be
	 * declared directly in a class/interface body, and every reference to it must be a supported top-level
	 * collector call in one of that type's methods.
	 */
	private static boolean isMigratable(FieldDeclaration field) {
		if (field.fragments().size() != 1 || !(field.getParent() instanceof TypeDeclaration)) {
			return false;
		}
		TypeDeclaration parentClass = (TypeDeclaration) field.getParent();
		VariableDeclarationFragment fragment = (VariableDeclarationFragment) field.fragments().get(0);
		List<SimpleName> references = new ArrayList<>();
		parentClass.accept(new ASTVisitor() {
			@Override
			public boolean visit(SimpleName name) {
				if (refersToField(name, fragment)) {
					references.add(name);
				}
				return false;
			}
		});
		return references.stream().allMatch(reference -> isMigratableUsage(reference, parentClass, fragment));
	}

	/**
	 * Returns whether {@code name} references the field declared by {@code fragment} (not the declaration itself).
	 * Uses bindings when available so that locals or members of other types with the same name are ignored.
	 */
	private static boolean refersToField(SimpleName name, VariableDeclarationFragment fragment) {
		SimpleName declaredName = fragment.getName();
		if (name == declaredName || !declaredName.getIdentifier().equals(name.getIdentifier())) {
			return false;
		}
		var fieldBinding = fragment.resolveBinding();
		var nameBinding = name.resolveBinding();
		if (fieldBinding == null || nameBinding == null) {
			// Without bindings fall back to a purely textual match
			return true;
		}
		return fieldBinding.isEqualTo(nameBinding);
	}

	/**
	 * Returns whether a field reference is the receiver of a supported collector call that forms a statement
	 * directly inside the body of a method declared in {@code parentClass}.
	 */
	private static boolean isMigratableUsage(SimpleName reference, TypeDeclaration parentClass,
			VariableDeclarationFragment fragment) {
		if (reference.getLocationInParent() != MethodInvocation.EXPRESSION_PROPERTY) {
			return false;
		}
		ASTNode invocation = reference.getParent();
		ASTNode statement = invocation.getParent();
		if (!(statement instanceof Statement) || asCollectorCall((Statement) statement, fragment) != invocation) {
			return false;
		}
		ASTNode body = statement.getParent();
		return body.getLocationInParent() == MethodDeclaration.BODY_PROPERTY
				&& body.getParent().getParent() == parentClass;
	}

	/**
	 * Returns the collector call made by {@code statement} if the statement is exactly
	 * {@code field.checkThat(..)}, {@code field.addError(..)} or {@code field.checkSucceeds(..)} with a supported
	 * argument count; {@code null} otherwise.
	 */
	private static MethodInvocation asCollectorCall(Statement statement, VariableDeclarationFragment fragment) {
		if (!(statement instanceof ExpressionStatement)) {
			return null;
		}
		Expression expression = ((ExpressionStatement) statement).getExpression();
		if (!(expression instanceof MethodInvocation)) {
			return null;
		}
		MethodInvocation invocation = (MethodInvocation) expression;
		Expression receiver = invocation.getExpression();
		if (!(receiver instanceof SimpleName) || !refersToField((SimpleName) receiver, fragment)) {
			return null;
		}
		int argumentCount = invocation.arguments().size();
		boolean supported = switch (invocation.getName().getIdentifier()) {
			case CHECK_THAT -> argumentCount == 2 || argumentCount == 3;
			case ADD_ERROR, CHECK_SUCCEEDS -> argumentCount == 1;
			default -> false;
		};
		return supported ? invocation : null;
	}

	/**
	 * Replaces each run of consecutive top-level collector statements in the method body with one assertAll().
	 */
	private static void transformTestMethod(MethodDeclaration method, VariableDeclarationFragment fragment,
			ASTRewrite rewriter, AST ast, TextEditGroup group, ImportRewrite importRewriter) {
		Block methodBody = method.getBody();
		if (methodBody == null) {
			return;
		}
		List<MethodInvocation> run = new ArrayList<>();
		for (Object statement : methodBody.statements()) {
			MethodInvocation call = asCollectorCall((Statement) statement, fragment);
			if (call != null) {
				run.add(call);
				continue;
			}
			// Any other statement ends the current run so it keeps its position between the checks
			replaceRun(methodBody, run, rewriter, ast, group, importRewriter);
			run.clear();
		}
		replaceRun(methodBody, run, rewriter, ast, group, importRewriter);
	}

	/** Replaces the statements of a non-empty run with a single assertAll() statement at the run's position. */
	private static void replaceRun(Block methodBody, List<MethodInvocation> run, ASTRewrite rewriter, AST ast,
			TextEditGroup group, ImportRewrite importRewriter) {
		if (run.isEmpty()) {
			return;
		}
		MethodInvocation assertAllCall = newStaticCall(ast, importRewriter, ORG_JUNIT_JUPITER_API_ASSERTIONS,
				"assertAll"); //$NON-NLS-1$
		for (MethodInvocation call : run) {
			assertAllCall.arguments().add(createLambdaForErrorCollectorCall(call, ast, importRewriter));
		}
		rewriter.getListRewrite(methodBody, Block.STATEMENTS_PROPERTY)
				.insertBefore(ast.newExpressionStatement(assertAllCall), run.get(0).getParent(), group);
		for (MethodInvocation call : run) {
			rewriter.remove(call.getParent(), group);
		}
	}

	/**
	 * Creates a call to a statically imported method. If the import conflicts with an existing one,
	 * {@link ImportRewrite#addStaticImport(String, String, boolean)} returns the qualified name, and the call is
	 * qualified with its declaring type instead.
	 */
	private static MethodInvocation newStaticCall(AST ast, ImportRewrite importRewriter, String declaringType,
			String methodName) {
		String nameToUse = importRewriter.addStaticImport(declaringType, methodName, false);
		MethodInvocation call = ast.newMethodInvocation();
		int lastDot = nameToUse.lastIndexOf('.');
		if (lastDot >= 0) {
			call.setExpression(ast.newName(nameToUse.substring(0, lastDot)));
			call.setName(ast.newSimpleName(nameToUse.substring(lastDot + 1)));
		} else {
			call.setName(ast.newSimpleName(nameToUse));
		}
		return call;
	}

	/** Builds the {@code Executable} lambda that replaces one collector call. */
	private static LambdaExpression createLambdaForErrorCollectorCall(MethodInvocation invocation, AST ast,
			ImportRewrite importRewriter) {
		LambdaExpression lambda = ast.newLambdaExpression();
		lambda.setParentheses(true);
		List<?> arguments = invocation.arguments();
		String methodName = invocation.getName().getIdentifier();
		switch (methodName) {
			case CHECK_THAT -> {
				// checkThat([reason,] actual, matcher) → () -> assertThat([reason,] actual, matcher)
				MethodInvocation assertThatCall = newStaticCall(ast, importRewriter, "org.hamcrest.MatcherAssert", //$NON-NLS-1$
						"assertThat"); //$NON-NLS-1$
				for (Object argument : arguments) {
					assertThatCall.arguments().add(ASTNode.copySubtree(ast, (ASTNode) argument));
				}
				lambda.setBody(assertThatCall);
			}
			case ADD_ERROR -> {
				// addError(throwable) → () -> { throw throwable; } (throw is a statement, so a block body is needed)
				ThrowStatement throwStmt = ast.newThrowStatement();
				throwStmt.setExpression((Expression) ASTNode.copySubtree(ast, (ASTNode) arguments.get(0)));
				Block lambdaBody = ast.newBlock();
				lambdaBody.statements().add(throwStmt);
				lambda.setBody(lambdaBody);
			}
			case CHECK_SUCCEEDS -> {
				// checkSucceeds(callable) → () -> callable.call()
				MethodInvocation callInvocation = ast.newMethodInvocation();
				callInvocation.setExpression(asCallReceiver((Expression) arguments.get(0), ast, importRewriter));
				callInvocation.setName(ast.newSimpleName("call")); //$NON-NLS-1$
				lambda.setBody(callInvocation);
			}
			default -> throw new IllegalStateException("Unsupported ErrorCollector call: " + methodName); //$NON-NLS-1$
		}
		return lambda;
	}

	/**
	 * Copies the {@code Callable} argument of {@code checkSucceeds} so it can be used as the receiver of
	 * {@code .call()}. Lambdas and method references have no standalone type and get a {@code (Callable<?>)}
	 * cast. Expressions whose operators bind looser than member access are parenthesized.
	 */
	private static Expression asCallReceiver(Expression callable, AST ast, ImportRewrite importRewriter) {
		Expression receiver = (Expression) ASTNode.copySubtree(ast, callable);
		switch (callable.getNodeType()) {
			case ASTNode.SIMPLE_NAME, ASTNode.QUALIFIED_NAME, ASTNode.THIS_EXPRESSION, ASTNode.FIELD_ACCESS,
					ASTNode.SUPER_FIELD_ACCESS, ASTNode.METHOD_INVOCATION, ASTNode.SUPER_METHOD_INVOCATION,
					ASTNode.CLASS_INSTANCE_CREATION, ASTNode.ARRAY_ACCESS, ASTNode.PARENTHESIZED_EXPRESSION -> {
				return receiver;
			}
			case ASTNode.LAMBDA_EXPRESSION, ASTNode.EXPRESSION_METHOD_REFERENCE, ASTNode.TYPE_METHOD_REFERENCE,
					ASTNode.SUPER_METHOD_REFERENCE, ASTNode.CREATION_REFERENCE -> {
				String callableType = importRewriter.addImport("java.util.concurrent.Callable"); //$NON-NLS-1$
				var castType = ast.newParameterizedType(ast.newSimpleType(ast.newName(callableType)));
				castType.typeArguments().add(ast.newWildcardType());
				var cast = ast.newCastExpression();
				cast.setType(castType);
				cast.setExpression(receiver);
				receiver = cast;
			}
			default -> {
				// Operator expressions (conditional, cast, assignment, ...) only need parentheses
			}
		}
		var parenthesized = ast.newParenthesizedExpression();
		parenthesized.setExpression(receiver);
		return parenthesized;
	}

	@Override
	public String getPreview(boolean afterRefactoring) {
		if (afterRefactoring) {
			return """
			import static org.junit.jupiter.api.Assertions.assertAll;
			import static org.hamcrest.MatcherAssert.assertThat;
			import static org.hamcrest.CoreMatchers.equalTo;
			import org.junit.jupiter.api.Test;
			public class MyTest {
			@Test
			public void testMultipleErrors() {
			assertAll(
			() -> assertThat("value1", equalTo("expected1")),
			() -> assertThat("value2", equalTo("expected2")),
			() -> { throw new Throwable("error message"); }
			);
			}
			}
			"""; //$NON-NLS-1$
		}
		return """
		import org.junit.Rule;
		import org.junit.Test;
		import org.junit.rules.ErrorCollector;
		import static org.hamcrest.CoreMatchers.equalTo;
		public class MyTest {
		@Rule
		public ErrorCollector collector = new ErrorCollector();
		@Test
		public void testMultipleErrors() {
		collector.checkThat("value1", equalTo("expected1"));
		collector.checkThat("value2", equalTo("expected2"));
		collector.addError(new Throwable("error message"));
		}
		}
		"""; //$NON-NLS-1$
	}

	@Override
	public String toString() {
		return "RuleErrorCollector"; //$NON-NLS-1$
	}
}