TestExpectedJUnitPlugin.java

/*******************************************************************************
 * Copyright (c) 2026 Carsten Hammer.
 *
 * This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License 2.0
 * which accompanies this distribution, and is available at
 * https://www.eclipse.org/legal/epl-2.0/
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 * Contributors:
 *     Carsten Hammer
 *******************************************************************************/
package org.sandbox.jdt.internal.corext.fix.helper;

import static org.sandbox.jdt.internal.corext.fix.helper.lib.JUnitConstants.*;

import java.util.List;
import java.util.Set;

import org.eclipse.jdt.core.dom.AST;
import org.eclipse.jdt.core.dom.ASTNode;
import org.eclipse.jdt.core.dom.Block;
import org.eclipse.jdt.core.dom.CompilationUnit;
import org.eclipse.jdt.core.dom.Expression;
import org.eclipse.jdt.core.dom.ExpressionStatement;
import org.eclipse.jdt.core.dom.LambdaExpression;
import org.eclipse.jdt.core.dom.MarkerAnnotation;
import org.eclipse.jdt.core.dom.MemberValuePair;
import org.eclipse.jdt.core.dom.MethodDeclaration;
import org.eclipse.jdt.core.dom.MethodInvocation;
import org.eclipse.jdt.core.dom.NormalAnnotation;
import org.eclipse.jdt.core.dom.Statement;
import org.eclipse.jdt.core.dom.TypeLiteral;
import org.eclipse.jdt.core.dom.rewrite.ASTRewrite;
import org.eclipse.jdt.core.dom.rewrite.ImportRewrite;
import org.eclipse.jdt.internal.corext.dom.ASTNodes;
import org.eclipse.jdt.internal.corext.fix.CompilationUnitRewriteOperationsFixCore.CompilationUnitRewriteOperationWithSourceRange;
import org.eclipse.text.edits.TextEditGroup;
import org.sandbox.jdt.internal.common.HelperVisitor;
import org.sandbox.jdt.internal.common.ReferenceHolder;
import org.sandbox.jdt.internal.corext.fix.JUnitCleanUpFixCore;
import org.sandbox.jdt.internal.corext.fix.helper.lib.AbstractTool;
import org.sandbox.jdt.internal.corext.fix.helper.lib.JunitHolder;

/**
 * Plugin to migrate JUnit 4 {@code @Test(expected=...)} to JUnit 5 {@code assertThrows()}.
 * 
 * Transforms:
 * <pre>
 * {@literal @}Test(expected = IllegalArgumentException.class)
 * public void testException() {
 *     // code that throws
 * }
 * </pre>
 * 
 * To:
 * <pre>
 * {@literal @}Test
 * public void testException() {
 *     assertThrows(IllegalArgumentException.class, () -> {
 *         // code that throws
 *     });
 * }
 * </pre>
 *
 * <p>Only {@code expected} values written as class literals are migrated. The statements of
 * the test method are moved (not copied) into the lambda, so comments and formatting inside
 * the method body are kept.</p>
 */
public class TestExpectedJUnitPlugin extends AbstractTool<ReferenceHolder<Integer, JunitHolder>> {

	/** Name of the JUnit 4 {@code @Test} attribute that holds the expected exception type. */
	private static final String PARAM_EXPECTED = "expected"; //$NON-NLS-1$

	/** Name of the JUnit 4 {@code @Test} timeout attribute (handled by TestTimeoutJUnitPlugin). */
	private static final String PARAM_TIMEOUT = "timeout"; //$NON-NLS-1$

	@Override
	public void find(JUnitCleanUpFixCore fixcore, CompilationUnit compilationUnit,
			Set<CompilationUnitRewriteOperationWithSourceRange> operations, Set<ASTNode> nodesprocessed) {
		ReferenceHolder<Integer, JunitHolder> dataHolder = new ReferenceHolder<>();
		HelperVisitor.forAnnotation(ORG_JUNIT_TEST)
			.in(compilationUnit)
			.excluding(nodesprocessed)
			.processEach(dataHolder, (visited, aholder) -> {
				// Only normal annotations (name = value pairs) can carry an "expected" attribute
				if (visited instanceof NormalAnnotation) {
					return processFoundNode(fixcore, operations, (NormalAnnotation) visited, aholder);
				}
				return true;
			});
	}

	/**
	 * Registers a rewrite operation for {@code node} if it has an {@code expected} attribute
	 * whose value is a class literal.
	 *
	 * @param fixcore    the cleanup core that creates the rewrite operation
	 * @param operations collects the created rewrite operations
	 * @param node       the JUnit 4 {@code @Test} annotation being inspected
	 * @param dataHolder accumulates the per-annotation data consumed by the rewrite
	 * @return always {@code true}; the fluent visitor API interprets {@code false} as
	 *         "stop processing all nodes"
	 */
	private boolean processFoundNode(JUnitCleanUpFixCore fixcore,
			Set<CompilationUnitRewriteOperationWithSourceRange> operations, NormalAnnotation node,
			ReferenceHolder<Integer, JunitHolder> dataHolder) {
		MemberValuePair expectedPair = findExpectedPair(node);
		// process2Rewrite can only migrate class literals; registering anything else would
		// announce a change that is then silently skipped.
		if (expectedPair != null && expectedPair.getValue() instanceof TypeLiteral) {
			JunitHolder mh = new JunitHolder();
			mh.minv = node;
			mh.minvname = node.getTypeName().getFullyQualifiedName();
			mh.additionalInfo = expectedPair; // read back in process2Rewrite
			dataHolder.put(dataHolder.size(), mh);
			operations.add(fixcore.rewrite(dataHolder));
		}
		return true;
	}

	/**
	 * Returns the {@code expected} member-value pair of {@code annotation}.
	 *
	 * @param annotation the annotation to search
	 * @return the first pair named {@code expected}, or {@code null} if there is none
	 */
	private static MemberValuePair findExpectedPair(NormalAnnotation annotation) {
		@SuppressWarnings("unchecked")
		List<MemberValuePair> pairs = annotation.values();
		for (MemberValuePair pair : pairs) {
			if (PARAM_EXPECTED.equals(pair.getName().getIdentifier())) {
				return pair;
			}
		}
		return null;
	}

	@Override
	protected void process2Rewrite(TextEditGroup group, ASTRewrite rewriter, AST ast, ImportRewrite importRewriter,
			JunitHolder junitHolder) {
		NormalAnnotation testAnnotation = (NormalAnnotation) junitHolder.getAnnotation();
		MemberValuePair expectedPair = (MemberValuePair) junitHolder.additionalInfo;
		if (expectedPair == null || !(expectedPair.getValue() instanceof TypeLiteral)) {
			// Nothing this plugin can migrate (find() already filters these out)
			return;
		}
		TypeLiteral expectedTypeLiteral = (TypeLiteral) expectedPair.getValue();

		MethodDeclaration method = ASTNodes.getParent(testAnnotation, MethodDeclaration.class);
		if (method == null) {
			return;
		}
		Block methodBody = method.getBody();
		if (methodBody == null) {
			return;
		}

		wrapBodyInAssertThrows(group, rewriter, ast, methodBody, expectedTypeLiteral);
		removeExpectedAttribute(group, rewriter, ast, testAnnotation, expectedPair);

		// Update imports
		importRewriter.removeImport(ORG_JUNIT_TEST);
		importRewriter.addImport(ORG_JUNIT_JUPITER_TEST);
		importRewriter.addStaticImport(ORG_JUNIT_JUPITER_API_ASSERTIONS, METHOD_ASSERT_THROWS, false);
	}

	/**
	 * Replaces the statements of {@code methodBody} with one
	 * {@code assertThrows(<expected>, () -> { <original statements> });} statement.
	 *
	 * <p>The original statements are moved into the lambda with a move target instead of being
	 * duplicated with {@link ASTNode#copySubtree}. Copied nodes carry no comments, so copying
	 * would delete every comment in the test body.</p>
	 *
	 * @param group               edit group for the produced text edits
	 * @param rewriter            rewriter that records the changes
	 * @param ast                 AST used to create new nodes
	 * @param methodBody          body of the annotated test method
	 * @param expectedTypeLiteral the class literal from the {@code expected} attribute
	 */
	private static void wrapBodyInAssertThrows(TextEditGroup group, ASTRewrite rewriter, AST ast, Block methodBody,
			TypeLiteral expectedTypeLiteral) {
		@SuppressWarnings("unchecked")
		List<Statement> statements = methodBody.statements();

		Block lambdaBody = ast.newBlock();
		LambdaExpression lambda = ast.newLambdaExpression();
		lambda.setParentheses(true);
		lambda.setBody(lambdaBody);

		MethodInvocation assertThrowsCall = ast.newMethodInvocation();
		assertThrowsCall.setName(ast.newSimpleName(METHOD_ASSERT_THROWS));
		@SuppressWarnings("unchecked")
		List<Expression> arguments = assertThrowsCall.arguments();
		arguments.add((TypeLiteral) ASTNode.copySubtree(ast, expectedTypeLiteral));
		arguments.add(lambda);

		ExpressionStatement assertThrowsStatement = ast.newExpressionStatement(assertThrowsCall);

		if (statements.isEmpty()) {
			// Nothing to move: the assertion becomes the only statement, with an empty lambda body
			rewriter.getListRewrite(methodBody, Block.STATEMENTS_PROPERTY).insertLast(assertThrowsStatement, group);
			return;
		}

		// Move the whole statement range into the lambda; the range's old location is taken
		// over by the assertThrows statement, so no separate removal is needed.
		ASTNode movedStatements = rewriter.getListRewrite(methodBody, Block.STATEMENTS_PROPERTY).createMoveTarget(
				statements.get(0), statements.get(statements.size() - 1), assertThrowsStatement, group);
		@SuppressWarnings("unchecked")
		List<Statement> lambdaStatements = lambdaBody.statements();
		lambdaStatements.add((Statement) movedStatements);
	}

	/**
	 * Drops the {@code expected} attribute from {@code testAnnotation}.
	 *
	 * <p>If the annotation has no attribute other than {@code expected} and {@code timeout}, it is
	 * replaced by a marker {@code @Test}. Otherwise only the {@code expected} pair is removed.
	 * {@code timeout} is not counted because TestTimeoutJUnitPlugin handles it.
	 * NOTE(review): if that plugin does not also run, the marker replacement drops the timeout;
	 * confirm this is intended.</p>
	 *
	 * @param group          edit group for the produced text edits
	 * @param rewriter       rewriter that records the changes
	 * @param ast            AST used to create new nodes
	 * @param testAnnotation the JUnit 4 {@code @Test} annotation
	 * @param expectedPair   its {@code expected} attribute
	 */
	private static void removeExpectedAttribute(TextEditGroup group, ASTRewrite rewriter, AST ast,
			NormalAnnotation testAnnotation, MemberValuePair expectedPair) {
		if (hasRetainedAttributes(testAnnotation)) {
			rewriter.remove(expectedPair, group);
			return;
		}
		MarkerAnnotation markerTestAnnotation = ast.newMarkerAnnotation();
		markerTestAnnotation.setTypeName(ast.newSimpleName(ANNOTATION_TEST));
		ASTNodes.replaceButKeepComment(rewriter, testAnnotation, markerTestAnnotation, group);
	}

	/**
	 * Tells whether {@code annotation} has any attribute other than {@code expected} and
	 * {@code timeout}.
	 *
	 * @param annotation the annotation to inspect
	 * @return {@code true} if at least one such attribute exists
	 */
	private static boolean hasRetainedAttributes(NormalAnnotation annotation) {
		@SuppressWarnings("unchecked")
		List<MemberValuePair> pairs = annotation.values();
		for (MemberValuePair pair : pairs) {
			String name = pair.getName().getIdentifier();
			if (!PARAM_EXPECTED.equals(name) && !PARAM_TIMEOUT.equals(name)) {
				return true;
			}
		}
		return false;
	}

	@Override
	public String getPreview(boolean afterRefactoring) {
		if (afterRefactoring) {
			return """
					import static org.junit.jupiter.api.Assertions.assertThrows;
					
					import org.junit.jupiter.api.Test;
					
					@Test
					public void testException() {
						assertThrows(IllegalArgumentException.class, () -> {
							throw new IllegalArgumentException("Expected");
						});
					}
					"""; //$NON-NLS-1$
		}
		return """
				import org.junit.Test;
				
				@Test(expected = IllegalArgumentException.class)
				public void testException() {
					throw new IllegalArgumentException("Expected");
				}
				"""; //$NON-NLS-1$
	}

	@Override
	public String toString() {
		return "TestExpected"; //$NON-NLS-1$
	}
}