TestJUnit3Plugin.java

/*******************************************************************************
 * Copyright (c) 2021 Carsten Hammer.
 *
 * This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License 2.0
 * which accompanies this distribution, and is available at
 * https://www.eclipse.org/legal/epl-2.0/
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 * Contributors:
 *     Carsten Hammer
 *******************************************************************************/
package org.sandbox.jdt.internal.corext.fix.helper;

import static org.sandbox.jdt.internal.corext.fix.helper.lib.JUnitConstants.*;

import java.util.Set;

import org.eclipse.jdt.core.dom.AST;
import org.eclipse.jdt.core.dom.ASTNode;
import org.eclipse.jdt.core.dom.Annotation;
import org.eclipse.jdt.core.dom.CompilationUnit;
import org.eclipse.jdt.core.dom.MarkerAnnotation;
import org.eclipse.jdt.core.dom.MethodDeclaration;
import org.eclipse.jdt.core.dom.MethodInvocation;
import org.eclipse.jdt.core.dom.Modifier;
import org.eclipse.jdt.core.dom.PrimitiveType;
import org.eclipse.jdt.core.dom.SimpleName;
import org.eclipse.jdt.core.dom.Type;
import org.eclipse.jdt.core.dom.TypeDeclaration;
import org.eclipse.jdt.core.dom.rewrite.ASTRewrite;
import org.eclipse.jdt.core.dom.rewrite.ImportRewrite;
import org.eclipse.jdt.core.dom.rewrite.ListRewrite;
import org.eclipse.jdt.internal.corext.fix.CompilationUnitRewriteOperationsFixCore.CompilationUnitRewriteOperationWithSourceRange;
import org.eclipse.text.edits.TextEditGroup;
import org.sandbox.jdt.internal.corext.util.AnnotationUtils;
import org.sandbox.jdt.internal.common.AstProcessorBuilder;
import org.sandbox.jdt.internal.common.HelperVisitorFactory;
import org.sandbox.jdt.internal.common.ReferenceHolder;
import org.sandbox.jdt.internal.corext.fix.JUnitCleanUpFixCore;
import org.sandbox.jdt.internal.corext.fix.helper.lib.AbstractTool;
import org.sandbox.jdt.internal.corext.fix.helper.lib.JunitHolder;

/**
 * Plugin to migrate JUnit 3 {@code TestCase} subclasses to JUnit 5 (Jupiter).
 * <p>
 * The rewrite performs the following steps:
 * <ul>
 * <li>Removes {@code extends TestCase} and the {@code junit.framework.TestCase} import.</li>
 * <li>Annotates {@code setUp()} / {@code tearDown()} with {@code @BeforeEach} / {@code @AfterEach}
 * and drops their {@code @Override}.</li>
 * <li>Annotates test methods with {@code @Test}.</li>
 * <li>Redirects JUnit 3 assertions to {@code org.junit.jupiter.api.Assertions}.</li>
 * </ul>
 */
public class TestJUnit3Plugin extends AbstractTool<ReferenceHolder<Integer, JunitHolder>> {

	/** Fully qualified name of the JUnit 3 base class this plugin migrates away from. */
	private static final String JUNIT3_TEST_CASE = "junit.framework.TestCase"; //$NON-NLS-1$

	/**
	 * JUnit 3 assertion methods that have a same-named counterpart in
	 * {@code org.junit.jupiter.api.Assertions}.
	 * <p>
	 * Only these are redirected. Other members of {@code junit.framework.Assert}/{@code TestCase}
	 * must stay untouched, because they have no counterpart in {@code Assertions}. Examples are the
	 * static helpers {@code failNotEquals} and {@code format}, and instance methods such as
	 * {@code getName()}.
	 */
	private static final Set<String> KNOWN_JUNIT3_ASSERTION_METHODS = Set.of(
			"assertEquals", "assertArrayEquals", "assertTrue", "assertFalse", //$NON-NLS-1$ //$NON-NLS-2$ //$NON-NLS-3$ //$NON-NLS-4$
			"assertNull", "assertNotNull", "assertSame", "assertNotSame", "fail"); //$NON-NLS-1$ //$NON-NLS-2$ //$NON-NLS-3$ //$NON-NLS-4$ //$NON-NLS-5$

	@Override
	public void find(JUnitCleanUpFixCore fixcore, CompilationUnit compilationUnit,
			Set<CompilationUnitRewriteOperationWithSourceRange> operations, Set<ASTNode> nodesprocessed) {
		ReferenceHolder<Integer, JunitHolder> dataHolder = ReferenceHolder.createIndexed();
		HelperVisitorFactory.callTypeDeclarationVisitor(JUNIT3_TEST_CASE, compilationUnit, dataHolder,
				nodesprocessed,
				(visited, aholder) -> processFoundNode(fixcore, operations, visited, aholder, nodesprocessed));
	}

	/**
	 * Registers a rewrite operation for a type declaration reported by the {@code TestCase} visitor.
	 * <p>
	 * A type is registered only if it is not already in {@code nodesprocessed} and declares at least
	 * one method that {@link #isTestMethod} does not classify as a test. Examples of such methods are
	 * a constructor, a non-public {@code setUp()}, or a helper.
	 * NOTE(review): classes consisting solely of test methods are therefore skipped — confirm this
	 * is intended.
	 *
	 * @return always {@code false}; presumably this tells the visitor to stop descending — TODO confirm
	 */
	private boolean processFoundNode(JUnitCleanUpFixCore fixcore,
			Set<CompilationUnitRewriteOperationWithSourceRange> operations, TypeDeclaration node,
			ReferenceHolder<Integer, JunitHolder> dataHolder, Set<ASTNode> nodesprocessed) {
		if (nodesprocessed.contains(node)) {
			return false;
		}
		boolean hasNonTestMethod = false;
		for (MethodDeclaration method : node.getMethods()) {
			if (!isTestMethod(method)) {
				hasNonTestMethod = true;
				break;
			}
		}
		if (!hasNonTestMethod) {
			return false;
		}

		nodesprocessed.add(node);
		JunitHolder mh = new JunitHolder();
		mh.setMinv(node);
		dataHolder.put(dataHolder.size(), mh);
		operations.add(fixcore.rewrite(dataHolder));
		return false;
	}

	/**
	 * Heuristically decides whether {@code method} is a JUnit 3 test method.
	 * <p>
	 * Constructors are never tests. Any other method counts as a test if its name:
	 * <ul>
	 * <li>starts with {@code test} or {@code should},</li>
	 * <li>ends with {@code _test}, or</li>
	 * <li>contains {@code Test}.</li>
	 * </ul>
	 * Failing that, every public, {@code void}, parameterless method is treated as a test.
	 */
	private boolean isTestMethod(MethodDeclaration method) {
		if (method.isConstructor()) {
			return false;
		}

		String methodName = method.getName().getIdentifier();

		// Typical JUnit 3 naming plus alternative naming schemes
		if (methodName.startsWith("test") || methodName.endsWith("_test") || methodName.startsWith("should")
				|| methodName.contains("Test")) {
			return true;
		}

		// Fallback: public, void, no parameters
		return Modifier.isPublic(method.getModifiers()) && isVoidReturnType(method)
				&& method.parameters().isEmpty();
	}

	/**
	 * Returns whether JUnit Jupiter can execute {@code method} as a {@code @Test}.
	 * <p>
	 * Jupiter requires test methods to return {@code void} and to be neither static, private nor
	 * abstract. Parameters would need a {@code ParameterResolver}, which a plain JUnit 3 class does
	 * not have. Helpers like {@code testEquality(Object, Object)} are matched by the name heuristic
	 * in {@link #isTestMethod} but must not be annotated.
	 */
	private boolean canBeJupiterTestMethod(MethodDeclaration method) {
		int modifiers = method.getModifiers();
		return method.parameters().isEmpty() && isVoidReturnType(method) && !Modifier.isStatic(modifiers)
				&& !Modifier.isPrivate(modifiers) && !Modifier.isAbstract(modifiers);
	}

	/**
	 * Returns whether {@code type} denotes {@code junit.framework.TestCase}.
	 * <p>
	 * The check is textual and accepts both the simple and the fully qualified spelling.
	 */
	private static boolean isTestCaseType(Type type) {
		if (type == null) {
			return false;
		}
		String typeName = type.toString();
		return "TestCase".equals(typeName) || JUNIT3_TEST_CASE.equals(typeName); //$NON-NLS-1$
	}

	/**
	 * Rewrites one {@code TestCase} subclass to JUnit 5.
	 * <p>
	 * The rewrite proceeds as follows:
	 * <ol>
	 * <li>Drops {@code extends TestCase} and its import.</li>
	 * <li>Converts {@code setUp()}/{@code tearDown()} to {@code @BeforeEach}/{@code @AfterEach}.</li>
	 * <li>Adds {@code @Test} to runnable test methods.</li>
	 * <li>Redirects JUnit 3 assertions in every method that has a body.</li>
	 * </ol>
	 */
	@Override
	protected void process2Rewrite(TextEditGroup group, ASTRewrite rewriter, AST ast, ImportRewrite importRewriter,
			JunitHolder junitHolder) {
		TypeDeclaration node = junitHolder.getTypeDeclaration();

		Type superclass = node.getSuperclassType();
		if (isTestCaseType(superclass)) {
			rewriter.remove(superclass, group);
			importRewriter.removeImport(JUNIT3_TEST_CASE);
		}

		for (MethodDeclaration method : node.getMethods()) {
			if (isSetupMethod(method)) {
				convertToAnnotation(method, "BeforeEach", importRewriter, rewriter, ast, group);
			} else if (isTeardownMethod(method)) {
				convertToAnnotation(method, "AfterEach", importRewriter, rewriter, ast, group);
			} else if (isTestMethod(method) && canBeJupiterTestMethod(method)) {
				addAnnotationToMethod(method, "Test", importRewriter, rewriter, ast, group);
			}

			// Assertions may appear in tests, lifecycle methods and helpers alike
			if (method.getBody() != null) {
				rewriteAssertionsAndAssumptions(method, rewriter, ast, group, importRewriter);
			}
		}
	}

	/**
	 * Redirects every JUnit 3 assertion inside {@code method} to
	 * {@code org.junit.jupiter.api.Assertions}.
	 * <p>
	 * For each assertion, the arguments are reordered via {@code reorderParameters}, the call is
	 * qualified with {@code Assertions}, and the matching import is added.
	 */
	private void rewriteAssertionsAndAssumptions(MethodDeclaration method, ASTRewrite rewriter, AST ast,
			TextEditGroup group, ImportRewrite importRewriter) {
		ReferenceHolder<String, Object> holder = ReferenceHolder.create();
		AstProcessorBuilder.with(holder)
			.onMethodInvocation((node, h) -> {
				if (isJUnit3Assertion(node)) {
					reorderParameters(node, rewriter, group, ONEPARAM_ASSERTIONS, TWOPARAM_ASSERTIONS);

					// Replace the whole qualifier rather than renaming it in place. Renaming only works
					// for a SimpleName such as "Assert". It would throw for a QualifiedName
					// (junit.framework.Assert) or a "this" qualifier. Replacing the expression also covers
					// unqualified calls inherited from TestCase.
					SimpleName qualifier = ast.newSimpleName("Assertions");
					rewriter.set(node, MethodInvocation.EXPRESSION_PROPERTY, qualifier, group);

					addImportForAssertion(node.getName().getIdentifier(), importRewriter);
				}
				return true;
			})
			.build(method);
	}

	/**
	 * Returns whether {@code invocation} is a JUnit 3 assertion that can be redirected to
	 * {@code Assertions}.
	 * <p>
	 * The method name must be one of {@link #KNOWN_JUNIT3_ASSERTION_METHODS}. When bindings are
	 * available, the declaring class must also be a {@code junit.framework} class. Without bindings,
	 * only unqualified calls are accepted, on the assumption that they are inherited from
	 * {@code TestCase}.
	 */
	private static boolean isJUnit3Assertion(MethodInvocation invocation) {
		if (!KNOWN_JUNIT3_ASSERTION_METHODS.contains(invocation.getName().getIdentifier())) {
			return false;
		}
		var binding = invocation.resolveMethodBinding();
		if (binding != null) {
			String declaringClass = binding.getDeclaringClass().getQualifiedName();
			return "junit.framework.Assert".equals(declaringClass)
					|| JUNIT3_TEST_CASE.equals(declaringClass)
					|| "junit.framework.Assume".equals(declaringClass);
		}
		return invocation.getExpression() == null;
	}

	/** Adds the JUnit 5 (or Hamcrest) import matching {@code assertionMethod}, if one is known. */
	private void addImportForAssertion(String assertionMethod, ImportRewrite importRewriter) {
		String importToAdd = switch (assertionMethod) {
		case "assertEquals", "assertArrayEquals", "assertTrue", "assertFalse", "assertNull", "assertNotNull",
				"assertSame", "assertNotSame", "fail" -> ORG_JUNIT_JUPITER_API_ASSERTIONS;
		case "assumeTrue", "assumeFalse", "assumeNotNull" -> ORG_JUNIT_JUPITER_API_ASSUMPTIONS;
		case "assertThat" -> ORG_HAMCREST_MATCHER_ASSERT;
		default -> null;
		};

		if (importToAdd != null) {
			importRewriter.addImport(importToAdd);
		}
	}

	/** Returns whether {@code method} is a JUnit 3 {@code void setUp()} lifecycle method. */
	private boolean isSetupMethod(MethodDeclaration method) {
		return "setUp".equals(method.getName().getIdentifier()) && method.parameters().isEmpty()
				&& isVoidReturnType(method);
	}

	/** Returns whether {@code method} is a JUnit 3 {@code void tearDown()} lifecycle method. */
	private boolean isTeardownMethod(MethodDeclaration method) {
		return "tearDown".equals(method.getName().getIdentifier()) && method.parameters().isEmpty()
				&& isVoidReturnType(method);
	}

	/** Returns whether {@code method} declares the primitive return type {@code void}. */
	private boolean isVoidReturnType(MethodDeclaration method) {
		Type returnType = method.getReturnType2();
		return returnType != null && returnType.isPrimitiveType()
				&& PrimitiveType.VOID.equals(((PrimitiveType) returnType).getPrimitiveTypeCode());
	}

	/**
	 * Turns a JUnit 3 lifecycle method into an annotated JUnit 5 one.
	 * <p>
	 * The method drops {@code @Override} and gains {@code @org.junit.jupiter.api.<annotation>}.
	 * NOTE(review): {@code super.setUp()} / {@code super.tearDown()} calls in the body are not
	 * removed here. Once {@code extends TestCase} is dropped they no longer compile — confirm whether
	 * another cleanup step removes them.
	 */
	private void convertToAnnotation(MethodDeclaration method, String annotation, ImportRewrite importRewrite,
			ASTRewrite rewrite, AST ast, TextEditGroup group) {
		// @Override no longer applies since the superclass (TestCase) is being removed
		removeOverrideAnnotation(method, rewrite, group);
		addAnnotationToMethod(method, annotation, importRewrite, rewrite, ast, group);
	}

	/** Removes the first {@code @Override} annotation from {@code method}'s modifiers, if present. */
	private void removeOverrideAnnotation(MethodDeclaration method, ASTRewrite rewrite, TextEditGroup group) {
		for (Object modifier : method.modifiers()) {
			if (modifier instanceof Annotation annotation
					&& "Override".equals(annotation.getTypeName().getFullyQualifiedName())) {
				rewrite.remove(annotation, group);
				break;
			}
		}
	}

	/**
	 * Prepends {@code @<annotation>} to {@code method}'s modifiers and imports
	 * {@code org.junit.jupiter.api.<annotation>}.
	 */
	private void addAnnotationToMethod(MethodDeclaration method, String annotation, ImportRewrite importRewrite,
			ASTRewrite rewrite, AST ast, TextEditGroup group) {
		ListRewrite modifiers = rewrite.getListRewrite(method, MethodDeclaration.MODIFIERS2_PROPERTY);
		MarkerAnnotation newMarkerAnnotation = AnnotationUtils.createMarkerAnnotation(ast, annotation);
		modifiers.insertFirst(newMarkerAnnotation, group);
		importRewrite.addImport("org.junit.jupiter.api." + annotation);
	}

	@Override
	public String getPreview(boolean afterRefactoring) {
		if (afterRefactoring) {
			return """
					import org.junit.jupiter.api.Test;
					"""; //$NON-NLS-1$
		}
		return """
				import junit.framework.TestCase;
				"""; //$NON-NLS-1$
	}

	@Override
	public String toString() {
		return "TestCase"; //$NON-NLS-1$
	}
}