RuleErrorCollectorJUnitPlugin.java

/*******************************************************************************
 * Copyright (c) 2026 Carsten Hammer.
 *
 * This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License 2.0
 * which accompanies this distribution, and is available at
 * https://www.eclipse.org/legal/epl-2.0/
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 * Contributors:
 *     Carsten Hammer
 *******************************************************************************/
package org.sandbox.jdt.internal.corext.fix.helper;

import static org.sandbox.jdt.internal.corext.fix.helper.lib.JUnitConstants.*;

import java.util.ArrayList;
import java.util.List;

import org.eclipse.jdt.core.dom.AST;
import org.eclipse.jdt.core.dom.ASTNode;
import org.eclipse.jdt.core.dom.Block;
import org.eclipse.jdt.core.dom.Expression;
import org.eclipse.jdt.core.dom.ExpressionStatement;
import org.eclipse.jdt.core.dom.FieldDeclaration;
import org.eclipse.jdt.core.dom.ITypeBinding;
import org.eclipse.jdt.core.dom.LambdaExpression;
import org.eclipse.jdt.core.dom.MethodDeclaration;
import org.eclipse.jdt.core.dom.MethodInvocation;
import org.eclipse.jdt.core.dom.SimpleName;
import org.eclipse.jdt.core.dom.Statement;
import org.eclipse.jdt.core.dom.ThrowStatement;
import org.eclipse.jdt.core.dom.TypeDeclaration;
import org.eclipse.jdt.core.dom.VariableDeclarationFragment;
import org.eclipse.jdt.core.dom.rewrite.ASTRewrite;
import org.eclipse.jdt.core.dom.rewrite.ImportRewrite;
import org.eclipse.jdt.internal.corext.dom.ASTNodes;
import org.eclipse.text.edits.TextEditGroup;
import org.sandbox.jdt.internal.common.AstProcessorBuilder;
import org.sandbox.jdt.internal.common.ReferenceHolder;
import org.sandbox.jdt.internal.corext.fix.helper.lib.JunitHolder;
import org.sandbox.jdt.internal.corext.fix.helper.lib.TriggerPatternCleanupPlugin;
import org.sandbox.jdt.triggerpattern.api.CleanupPattern;
import org.sandbox.jdt.triggerpattern.api.Match;
import org.sandbox.jdt.triggerpattern.api.PatternKind;

/**
 * Plugin to migrate the JUnit 4 {@code ErrorCollector} rule to JUnit 5 {@code assertAll}.
 * <p>
 * Transforms:
 * <ul>
 * <li>{@code collector.checkThat(actual, matcher)} → {@code () -> assertThat(actual, matcher)}</li>
 * <li>{@code collector.addError(throwable)} → {@code () -> { throw throwable; }}</li>
 * <li>{@code collector.checkSucceeds(callable)} → {@code () -> callable.call()}</li>
 * </ul>
 * <p>
 * All converted calls of one method are wrapped in a single {@code assertAll(...)} call, which is inserted before
 * the top-level statement that contains the first converted call. The {@code @Rule} field and its imports are
 * removed.
 * <p>
 * Only calls that form a complete expression statement are converted. Calls nested inside lambda expressions,
 * anonymous classes or local types are left untouched, because moving them into {@code assertAll} would detach them
 * from the variables they capture.
 *
 * @since 1.3.0
 */
@CleanupPattern(value = "@Rule public ErrorCollector $name", kind = PatternKind.FIELD, qualifiedType = ORG_JUNIT_RULES_ERROR_COLLECTOR, cleanupId = "cleanup.junit.ruleerrorcollector", description = "Migrate @Rule ErrorCollector to assertAll()", displayName = "JUnit 4 @Rule ErrorCollector \u2192 JUnit 5 assertAll()")
public class RuleErrorCollectorJUnitPlugin extends TriggerPatternCleanupPlugin {

	/**
	 * Creates a holder for a matched {@code @Rule ErrorCollector} field.
	 * <p>
	 * The type of the first fragment is verified through its binding. When no binding is available, or the binding
	 * reports a different qualified name (e.g. a recovered binding), the declared source type name is compared
	 * against the simple and the fully qualified {@code ErrorCollector} names instead.
	 *
	 * @param match the pattern match whose matched node is the field declaration
	 * @return a holder wrapping the field declaration, or {@code null} if the field is not an ErrorCollector
	 */
	@Override
	protected JunitHolder createHolder(Match match) {
		FieldDeclaration fieldDecl = (FieldDeclaration) match.getMatchedNode();
		VariableDeclarationFragment fragment = (VariableDeclarationFragment) fieldDecl.fragments().get(0);
		var variableBinding = fragment.resolveBinding();
		ITypeBinding binding = variableBinding != null ? variableBinding.getType() : null;
		boolean isErrorCollector;
		if (binding != null && ORG_JUNIT_RULES_ERROR_COLLECTOR.equals(binding.getQualifiedName())) {
			isErrorCollector = true;
		} else {
			// Fallback: check by source type name when binding is unavailable or recovered
			String typeName = fieldDecl.getType().toString();
			String simpleTypeName = ORG_JUNIT_RULES_ERROR_COLLECTOR
					.substring(ORG_JUNIT_RULES_ERROR_COLLECTOR.lastIndexOf('.') + 1);
			isErrorCollector = simpleTypeName.equals(typeName) || ORG_JUNIT_RULES_ERROR_COLLECTOR.equals(typeName);
		}
		if (!isErrorCollector) {
			return null;
		}
		JunitHolder holder = new JunitHolder();
		holder.setMinv(fieldDecl);
		return holder;
	}

	/**
	 * Removes the ErrorCollector field, swaps the JUnit 4 imports for the {@code assertAll} static import and
	 * rewrites every method of the declaring class that uses the field.
	 * <p>
	 * Fields that are not declared directly in a {@link TypeDeclaration} (for example in an enum or an anonymous
	 * class) are left untouched, since the methods using them are not reachable through
	 * {@link TypeDeclaration#getMethods()}.
	 */
	@Override
	protected void process2Rewrite(TextEditGroup group, ASTRewrite rewriter, AST ast, ImportRewrite importRewriter,
			JunitHolder junitHolder) {
		FieldDeclaration field = junitHolder.getFieldDeclaration();
		TypeDeclaration parentClass = ASTNodes.getParent(field, TypeDeclaration.class);
		if (parentClass == null || field.getParent() != parentClass) {
			// Rewriting here would remove the field while leaving its usages behind (or NPE on a null class).
			return;
		}

		VariableDeclarationFragment originalFragment = (VariableDeclarationFragment) field.fragments().get(0);
		String fieldName = originalFragment.getName().getIdentifier();

		// Remove the field declaration
		rewriter.remove(field, group);

		// Remove old imports
		importRewriter.removeImport(ORG_JUNIT_RULE);
		importRewriter.removeImport(ORG_JUNIT_RULES_ERROR_COLLECTOR);

		// Add new imports
		importRewriter.addStaticImport(ORG_JUNIT_JUPITER_API_ASSERTIONS, "assertAll", false);

		// Transform all methods that use the ErrorCollector field
		for (MethodDeclaration method : parentClass.getMethods()) {
			transformTestMethod(method, fieldName, rewriter, ast, group, importRewriter);
		}
	}

	/**
	 * Replaces all ErrorCollector calls of one method with a single {@code assertAll(...)} statement.
	 *
	 * @param method         the method to rewrite; abstract/native methods (no body) are skipped
	 * @param fieldName      the identifier of the ErrorCollector field
	 * @param rewriter       the rewrite that records all changes
	 * @param ast            the AST used to create new nodes
	 * @param group          the edit group the changes are attributed to
	 * @param importRewriter the import rewrite used to add the {@code assertThat} import when needed
	 */
	private void transformTestMethod(MethodDeclaration method, String fieldName, ASTRewrite rewriter, AST ast,
			TextEditGroup group, ImportRewrite importRewriter) {
		Block methodBody = method.getBody();
		if (methodBody == null) {
			return;
		}

		List<Statement> statements = methodBody.statements();
		if (statements.isEmpty()) {
			return;
		}

		// Find all ErrorCollector method invocations in this method
		List<ErrorCollectorCall> errorCollectorCalls = findErrorCollectorCalls(statements, fieldName);
		if (errorCollectorCalls.isEmpty()) {
			// This method doesn't use the ErrorCollector field
			return;
		}

		// Build assertAll(...) with one lambda per ErrorCollector call, in source order
		MethodInvocation assertAllCall = ast.newMethodInvocation();
		assertAllCall.setName(ast.newSimpleName("assertAll"));
		for (ErrorCollectorCall call : errorCollectorCalls) {
			LambdaExpression lambda = createLambdaForErrorCollectorCall(call, ast, importRewriter);
			assertAllCall.arguments().add(lambda);
		}
		ExpressionStatement assertAllStatement = ast.newExpressionStatement(assertAllCall);

		// Remove all old ErrorCollector calls
		for (int i = errorCollectorCalls.size() - 1; i >= 0; i--) {
			Statement oldStatement = errorCollectorCalls.get(i).statement;
			if (oldStatement.getLocationInParent().isChildListProperty()) {
				rewriter.remove(oldStatement, group);
			} else {
				// A required child such as the body of "if (c) collector.checkThat(..);" cannot be removed;
				// ASTRewrite rejects that, so replace it with an empty block instead.
				rewriter.replace(oldStatement, ast.newBlock(), group);
			}
		}

		// Anchor on the top-level statement containing the first call: a nested call's own statement is not
		// part of this list (indexOf would yield -1 and append at the end). insertBefore also stays correct
		// when other edits have already modified this statement list.
		Statement anchor = errorCollectorCalls.get(0).topLevelStatement;
		rewriter.getListRewrite(methodBody, Block.STATEMENTS_PROPERTY).insertBefore(assertAllStatement, anchor,
				group);
	}

	/**
	 * Collects the convertible ErrorCollector calls contained in the given top-level statements, including calls
	 * nested in blocks or control-flow statements.
	 *
	 * @param statements the top-level statements of a method body
	 * @param fieldName  the identifier of the ErrorCollector field
	 * @return the calls in source order; never {@code null}
	 */
	private List<ErrorCollectorCall> findErrorCollectorCalls(List<Statement> statements, String fieldName) {
		List<ErrorCollectorCall> calls = new ArrayList<>();

		for (Statement topLevel : statements) {
			ReferenceHolder<String, Object> holder = ReferenceHolder.create();
			AstProcessorBuilder.with(holder)
				.onMethodInvocation((invocation, h) -> {
					// Only a call that is the whole expression statement can be removed without deleting
					// surrounding code (declarations, returns, enclosing lambdas ...).
					if (isErrorCollectorCall(invocation, fieldName)
							&& invocation.getParent() instanceof ExpressionStatement
							&& !isInsideNestedBody(invocation, topLevel)) {
						calls.add(new ErrorCollectorCall((Statement) invocation.getParent(), topLevel, invocation,
								invocation.getName().getIdentifier()));
					}
					return true;
				})
				.build(topLevel);
		}

		return calls;
	}

	/**
	 * Returns whether the invocation is {@code <fieldName>.checkThat/addError/checkSucceeds(...)} with a plain
	 * simple-name receiver.
	 */
	private static boolean isErrorCollectorCall(MethodInvocation invocation, String fieldName) {
		Expression expression = invocation.getExpression();
		if (!(expression instanceof SimpleName)
				|| !fieldName.equals(((SimpleName) expression).getIdentifier())) {
			return false;
		}
		String methodName = invocation.getName().getIdentifier();
		return "checkThat".equals(methodName) || "addError".equals(methodName)
				|| "checkSucceeds".equals(methodName);
	}

	/**
	 * Returns whether {@code node} lies inside a lambda expression, anonymous class or local type declared within
	 * {@code topLevel}; such calls cannot be hoisted into {@code assertAll} without losing their captured context.
	 */
	private static boolean isInsideNestedBody(ASTNode node, Statement topLevel) {
		for (ASTNode current = node; current != null && current != topLevel; current = current.getParent()) {
			if (current instanceof LambdaExpression || current instanceof TypeDeclaration
					|| current.getNodeType() == ASTNode.ANONYMOUS_CLASS_DECLARATION
					|| current.getNodeType() == ASTNode.ENUM_DECLARATION) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Builds the {@code Executable} lambda replacing one ErrorCollector call.
	 *
	 * @param call           the call to convert
	 * @param ast            the AST used to create new nodes
	 * @param importRewriter receives the Hamcrest {@code assertThat} static import for {@code checkThat}
	 * @return a zero-argument lambda with copied arguments
	 */
	private LambdaExpression createLambdaForErrorCollectorCall(ErrorCollectorCall call, AST ast,
			ImportRewrite importRewriter) {
		LambdaExpression lambda = ast.newLambdaExpression();
		lambda.setParentheses(true);

		MethodInvocation invocation = call.invocation;
		String methodName = call.methodName;

		if ("checkThat".equals(methodName)) {
			// checkThat([reason,] actual, matcher) → () -> assertThat([reason,] actual, matcher);
			// MatcherAssert.assertThat offers the same overloads, so all arguments are copied verbatim.
			MethodInvocation assertThatCall = ast.newMethodInvocation();
			assertThatCall.setName(ast.newSimpleName("assertThat"));
			for (Object arg : invocation.arguments()) {
				assertThatCall.arguments().add(ASTNode.copySubtree(ast, (ASTNode) arg));
			}
			lambda.setBody(assertThatCall);

			// Add Hamcrest imports for assertThat
			importRewriter.addStaticImport("org.hamcrest.MatcherAssert", "assertThat", false);
		} else if ("addError".equals(methodName)) {
			// addError(throwable) → () -> { throw throwable; }
			// A block body is required since throw is a statement, not an expression.
			Block lambdaBody = ast.newBlock();
			ThrowStatement throwStmt = ast.newThrowStatement();
			Expression throwableArg = (Expression) invocation.arguments().get(0);
			throwStmt.setExpression((Expression) ASTNode.copySubtree(ast, throwableArg));
			lambdaBody.statements().add(throwStmt);
			lambda.setBody(lambdaBody);
		} else if ("checkSucceeds".equals(methodName)) {
			// checkSucceeds(callable) → () -> callable.call()
			// NOTE(review): a lambda or method-reference argument yields code that does not compile
			// (it has no target type as a call receiver); such inputs may need manual migration.
			Expression callableArg = (Expression) invocation.arguments().get(0);
			MethodInvocation callInvocation = ast.newMethodInvocation();
			callInvocation.setExpression((Expression) ASTNode.copySubtree(ast, callableArg));
			callInvocation.setName(ast.newSimpleName("call"));
			lambda.setBody(callInvocation);
		}

		return lambda;
	}

	/**
	 * Returns the code sample shown for this cleanup.
	 *
	 * @param afterRefactoring {@code true} for the migrated JUnit 5 sample, {@code false} for the JUnit 4 original
	 * @return the sample source text
	 */
	@Override
	public String getPreview(boolean afterRefactoring) {
		if (afterRefactoring) {
			return """
					import static org.junit.jupiter.api.Assertions.assertAll;
					import static org.hamcrest.MatcherAssert.assertThat;
					import static org.hamcrest.CoreMatchers.equalTo;

					import org.junit.jupiter.api.Test;

					public class MyTest {
						@Test
						public void testMultipleErrors() {
							assertAll(
								() -> assertThat("value1", equalTo("expected1")),
								() -> assertThat("value2", equalTo("expected2")),
								() -> { throw new Throwable("error message"); }
							);
						}
					}
					"""; //$NON-NLS-1$
		}
		return """
				import org.junit.Rule;
				import org.junit.Test;
				import org.junit.rules.ErrorCollector;
				import static org.hamcrest.CoreMatchers.equalTo;

				public class MyTest {
					@Rule
					public ErrorCollector collector = new ErrorCollector();

					@Test
					public void testMultipleErrors() {
						collector.checkThat("value1", equalTo("expected1"));
						collector.checkThat("value2", equalTo("expected2"));
						collector.addError(new Throwable("error message"));
					}
				}
				"""; //$NON-NLS-1$
	}

	@Override
	public String toString() {
		return "RuleErrorCollector"; //$NON-NLS-1$
	}

	/**
	 * One ErrorCollector call found in a method body.
	 */
	private static class ErrorCollectorCall {
		/** The expression statement consisting solely of the call; removed by the rewrite. */
		final Statement statement;
		/** The direct child of the method body that contains {@link #statement} (possibly itself). */
		final Statement topLevelStatement;
		/** The ErrorCollector method invocation. */
		final MethodInvocation invocation;
		/** The invoked method name: checkThat, addError or checkSucceeds. */
		final String methodName;

		ErrorCollectorCall(Statement statement, Statement topLevelStatement, MethodInvocation invocation,
				String methodName) {
			this.statement = statement;
			this.topLevelStatement = topLevelStatement;
			this.invocation = invocation;
			this.methodName = methodName;
		}
	}
}