AssertOptimizationJUnitPlugin.java

/*******************************************************************************
 * Copyright (c) 2026 Carsten Hammer.
 *
 * This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License 2.0
 * which accompanies this distribution, and is available at
 * https://www.eclipse.org/legal/epl-2.0/
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 * Contributors:
 *     Carsten Hammer
 *******************************************************************************/
package org.sandbox.jdt.internal.corext.fix.helper;

import static org.sandbox.jdt.internal.corext.fix.helper.lib.JUnitConstants.*;

import java.util.List;
import java.util.Set;

import org.eclipse.jdt.core.dom.AST;
import org.eclipse.jdt.core.dom.ASTNode;
import org.eclipse.jdt.core.dom.ArrayCreation;
import org.eclipse.jdt.core.dom.ArrayInitializer;
import org.eclipse.jdt.core.dom.BooleanLiteral;
import org.eclipse.jdt.core.dom.CharacterLiteral;
import org.eclipse.jdt.core.dom.CompilationUnit;
import org.eclipse.jdt.core.dom.Expression;
import org.eclipse.jdt.core.dom.FieldAccess;
import org.eclipse.jdt.core.dom.IMethodBinding;
import org.eclipse.jdt.core.dom.ITypeBinding;
import org.eclipse.jdt.core.dom.IVariableBinding;
import org.eclipse.jdt.core.dom.InfixExpression;
import org.eclipse.jdt.core.dom.MethodInvocation;
import org.eclipse.jdt.core.dom.Modifier;
import org.eclipse.jdt.core.dom.NullLiteral;
import org.eclipse.jdt.core.dom.NumberLiteral;
import org.eclipse.jdt.core.dom.PrefixExpression;
import org.eclipse.jdt.core.dom.QualifiedName;
import org.eclipse.jdt.core.dom.SimpleName;
import org.eclipse.jdt.core.dom.StringLiteral;
import org.eclipse.jdt.core.dom.TypeLiteral;
import org.eclipse.jdt.core.dom.rewrite.ASTRewrite;
import org.eclipse.jdt.core.dom.rewrite.ImportRewrite;
import org.eclipse.jdt.core.dom.rewrite.ListRewrite;
import org.eclipse.jdt.internal.corext.fix.CompilationUnitRewriteOperationsFixCore.CompilationUnitRewriteOperationWithSourceRange;
import org.eclipse.text.edits.TextEditGroup;
import org.sandbox.jdt.internal.common.HelperVisitor;
import org.sandbox.jdt.internal.common.ReferenceHolder;
import org.sandbox.jdt.internal.corext.fix.JUnitCleanUpFixCore;
import org.sandbox.jdt.internal.corext.fix.helper.lib.AbstractTool;
import org.sandbox.jdt.internal.corext.fix.helper.lib.JunitHolder;

/**
 * Optimizes JUnit assertions by converting generic assertions to more specific ones
 * and correcting parameter order (expected/actual).
 * 
 * Examples:
 * - assertTrue(a == b) → assertEquals(a, b)
 * - assertTrue(obj == null) → assertNull(obj)
 * - assertTrue(!condition) → assertFalse(condition)
 * - assertTrue(a.equals(b)) → assertEquals(a, b)
 * - assertEquals(getActual(), EXPECTED) → assertEquals(EXPECTED, getActual())
 */
public class AssertOptimizationJUnitPlugin extends AbstractTool<ReferenceHolder<Integer, JunitHolder>> {

	/**
	 * Assertion methods that have expected/actual parameter order.
	 * First parameter should be expected (constant/literal), second should be actual (computed).
	 */
	private static final Set<String> METHODS_WITH_EXPECTED_ACTUAL = Set.of(
		"assertEquals",
		"assertNotEquals",
		"assertArrayEquals",
		"assertSame",
		"assertNotSame",
		"assertIterableEquals",  // JUnit 5 only
		"assertLinesMatch"       // JUnit 5 only
	);

	/** Factory method names whose result counts as constant when every argument is constant. */
	private static final Set<String> COLLECTION_FACTORY_METHODS = Set.of("of", "asList");

	/** Simple receiver names accepted for the factory methods, e.g. {@code List.of(...)}. */
	private static final Set<String> COLLECTION_FACTORY_RECEIVERS = Set.of("List", "Set", "Arrays", "Map");

	@Override
	public void find(JUnitCleanUpFixCore fixcore, CompilationUnit compilationUnit,
			Set<CompilationUnitRewriteOperationWithSourceRange> operations, Set<ASTNode> nodesprocessed) {
		ReferenceHolder<Integer, JunitHolder> dataHolder = new ReferenceHolder<>();

		// NOTE: We only process JUnit 5 (Assertions) calls here.
		// JUnit 4 (Assert) calls are handled by AssertJUnitPlugin which does migration.
		// The optimization for JUnit 4 assertions should be done within the migration itself.

		// Find assertTrue and assertFalse calls for optimization (JUnit 5)
		HelperVisitor.forMethodCalls(ORG_JUNIT_JUPITER_API_ASSERTIONS, Set.of("assertTrue", "assertFalse"))
			.in(compilationUnit)
			.excluding(nodesprocessed)
			.processEach(dataHolder, (visited, aholder) -> processAssertion(fixcore, operations, visited, aholder));

		// Find assertion calls with expected/actual parameters for parameter order correction (JUnit 5)
		HelperVisitor.forMethodCalls(ORG_JUNIT_JUPITER_API_ASSERTIONS, METHODS_WITH_EXPECTED_ACTUAL)
			.in(compilationUnit)
			.excluding(nodesprocessed)
			.processEach(dataHolder, (visited, aholder) -> processParameterOrder(fixcore, operations, visited, aholder));
	}

	/**
	 * Registers a rewrite operation for an {@code assertTrue}/{@code assertFalse} call whose
	 * condition can be turned into a more specific assertion.
	 *
	 * @return the result of {@code addStandardRewriteOperation}, or {@code true} to continue visiting
	 */
	private boolean processAssertion(JUnitCleanUpFixCore fixcore,
			Set<CompilationUnitRewriteOperationWithSourceRange> operations, ASTNode node,
			ReferenceHolder<Integer, JunitHolder> dataHolder) {
		if (!(node instanceof MethodInvocation mi)) {
			return true; // Continue processing other nodes
		}
		Expression condition = findCondition(mi);
		if (condition == null || !canOptimize(condition)) {
			return true; // Continue processing other nodes
		}
		return addStandardRewriteOperation(fixcore, operations, node, dataHolder);
	}

	/**
	 * Registers a rewrite operation for an expected/actual assertion whose arguments are in the
	 * wrong order: the actual argument is a constant but the expected one is not.
	 *
	 * @return the result of {@code addStandardRewriteOperation}, or {@code true} to continue visiting
	 */
	private boolean processParameterOrder(JUnitCleanUpFixCore fixcore,
			Set<CompilationUnitRewriteOperationWithSourceRange> operations, ASTNode node,
			ReferenceHolder<Integer, JunitHolder> dataHolder) {
		if (node instanceof MethodInvocation mi && needsParameterSwap(mi)) {
			return addStandardRewriteOperation(fixcore, operations, node, dataHolder);
		}
		return true; // Continue processing other nodes
	}

	/**
	 * Returns the boolean condition of an {@code assertTrue}/{@code assertFalse} call.
	 * A one-argument call's condition is that argument. In a two-argument call a String-typed
	 * first argument is taken as a leading message (JUnit 4 style); otherwise the first argument is
	 * the condition (JUnit 5: condition, message).
	 *
	 * @return the condition, or {@code null} for any other argument count
	 */
	private static Expression findCondition(MethodInvocation mi) {
		List<?> arguments = mi.arguments();
		if (arguments.size() == 1) {
			return (Expression) arguments.get(0);
		}
		if (arguments.size() == 2) {
			Expression first = (Expression) arguments.get(0);
			return isStringTyped(first) ? (Expression) arguments.get(1) : first;
		}
		return null;
	}

	/**
	 * Returns the index of the "expected" argument of an expected/actual assertion.
	 * JUnit 5 never puts the message first, so index 0 is used for calls resolved to
	 * {@code org.junit.jupiter.api.Assertions}. Otherwise a String-typed first argument followed
	 * by at least two more arguments is taken as a JUnit 4 leading message, giving index 1.
	 */
	private static int expectedArgumentIndex(MethodInvocation mi) {
		List<?> arguments = mi.arguments();
		if (arguments.size() >= 3 && !isJupiterAssertion(mi) && isStringTyped((Expression) arguments.get(0))) {
			return 1; // JUnit 4: message, expected, actual
		}
		return 0; // JUnit 5: expected, actual [, message]
	}

	/** Returns {@code true} if the call resolves to a method declared in JUnit 5 {@code Assertions}. */
	private static boolean isJupiterAssertion(MethodInvocation mi) {
		IMethodBinding method = mi.resolveMethodBinding();
		ITypeBinding declaringClass = method != null ? method.getDeclaringClass() : null;
		return declaringClass != null
				&& ORG_JUNIT_JUPITER_API_ASSERTIONS.equals(declaringClass.getQualifiedName());
	}

	/** Returns {@code true} if the expression's resolved type is {@code java.lang.String}. */
	private static boolean isStringTyped(Expression expr) {
		ITypeBinding type = expr.resolveTypeBinding();
		return type != null && "java.lang.String".equals(type.getQualifiedName());
	}

	/**
	 * Returns {@code true} if the expected/actual arguments of the given assertion should be
	 * swapped because only the "actual" argument is a constant.
	 */
	private static boolean needsParameterSwap(MethodInvocation mi) {
		List<?> arguments = mi.arguments();
		int expectedIndex = expectedArgumentIndex(mi);
		if (arguments.size() < expectedIndex + 2) {
			return false;
		}
		return isReversedExpectedActual((Expression) arguments.get(expectedIndex),
				(Expression) arguments.get(expectedIndex + 1));
	}

	/**
	 * Returns {@code true} if {@code actual} is a constant but {@code expected} is not,
	 * i.e. the two values belong in the opposite positions.
	 */
	private static boolean isReversedExpectedActual(Expression expected, Expression actual) {
		return !isConstantExpression(expected) && isConstantExpression(actual);
	}

	/**
	 * Determines if an expression is a constant value.
	 * Constants include literals, compile-time constant expressions (e.g. {@code -1}),
	 * final fields, enum constants, class literals, array literals, collection factory methods
	 * with constant arguments, and method calls on a string literal.
	 */
	private static boolean isConstantExpression(Expression expr) {
		if (expr == null) {
			return false;
		}

		// Literals (null and class literals have no compile-time constant value, so list them explicitly)
		if (expr instanceof NumberLiteral || expr instanceof StringLiteral
				|| expr instanceof BooleanLiteral || expr instanceof CharacterLiteral
				|| expr instanceof NullLiteral || expr instanceof TypeLiteral) {
			return true;
		}

		// Compile-time constants known to the compiler: constant variables, -1, "a" + "b", ...
		if (expr.resolveConstantExpressionValue() != null) {
			return true;
		}

		// Signed constants such as -1 (JDT parses these as a prefix expression); also works without bindings
		if (expr instanceof PrefixExpression prefix) {
			PrefixExpression.Operator op = prefix.getOperator();
			return (op == PrefixExpression.Operator.MINUS || op == PrefixExpression.Operator.PLUS)
					&& isConstantExpression(prefix.getOperand());
		}

		// Final fields and enum constants: CONSTANT, MyClass.CONSTANT, this.constant
		if (expr instanceof SimpleName name) {
			return name.resolveBinding() instanceof IVariableBinding variable && isConstantField(variable);
		}
		if (expr instanceof QualifiedName qname) {
			return qname.resolveBinding() instanceof IVariableBinding variable && isConstantField(variable);
		}
		if (expr instanceof FieldAccess fieldAccess) {
			IVariableBinding binding = fieldAccess.resolveFieldBinding();
			return binding != null && isConstantField(binding);
		}

		// Array creation with initializer containing only constants
		if (expr instanceof ArrayCreation arrayCreation) {
			ArrayInitializer initializer = arrayCreation.getInitializer();
			return initializer != null && allConstant(initializer.expressions());
		}

		// Array initializer
		if (expr instanceof ArrayInitializer initializer) {
			return allConstant(initializer.expressions());
		}

		if (expr instanceof MethodInvocation call) {
			Expression receiver = call.getExpression();
			// Collection factory methods: List.of(...), Set.of(...), Arrays.asList(...), Map.of(...)
			if (COLLECTION_FACTORY_METHODS.contains(call.getName().getIdentifier())
					&& receiver instanceof SimpleName receiverName
					&& COLLECTION_FACTORY_RECEIVERS.contains(receiverName.getIdentifier())) {
				return allConstant(call.arguments());
			}
			// Method call on string literal: "test".getBytes()
			return receiver instanceof StringLiteral;
		}

		return false;
	}

	/** Returns {@code true} if every element of the given AST expression list is constant. */
	private static boolean allConstant(List<?> expressions) {
		return expressions.stream().allMatch(e -> isConstantExpression((Expression) e));
	}

	/**
	 * Returns {@code true} for enum constants and {@code final} fields. A non-final static field
	 * is mutable and therefore not considered a constant.
	 */
	private static boolean isConstantField(IVariableBinding binding) {
		return binding.isField() && (binding.isEnumConstant() || Modifier.isFinal(binding.getModifiers()));
	}

	/**
	 * Returns {@code true} if the condition has one of the supported shapes: {@code !x},
	 * a binary {@code ==}/{@code !=} (no chained operands), or a qualified {@code a.equals(b)}.
	 */
	private static boolean canOptimize(Expression condition) {
		// Check for prefix expression (!condition)
		if (condition instanceof PrefixExpression prefix) {
			return prefix.getOperator() == PrefixExpression.Operator.NOT;
		}

		// Check for infix expression (==, !=); "a == b == c" would lose operands, so reject it
		if (condition instanceof InfixExpression infix) {
			InfixExpression.Operator op = infix.getOperator();
			return (op == InfixExpression.Operator.EQUALS || op == InfixExpression.Operator.NOT_EQUALS)
					&& !infix.hasExtendedOperands();
		}

		// Check for .equals() method call
		if (condition instanceof MethodInvocation methodInv) {
			return isEqualsCall(methodInv);
		}

		return false;
	}

	/**
	 * Returns {@code true} for a one-argument {@code equals} call with an explicit receiver.
	 * An unqualified {@code equals(b)} has no receiver to use as the second assertion argument.
	 */
	private static boolean isEqualsCall(MethodInvocation call) {
		return "equals".equals(call.getName().getIdentifier())
				&& call.arguments().size() == 1
				&& call.getExpression() != null;
	}

	@Override
	protected void process2Rewrite(TextEditGroup group, ASTRewrite rewriter, AST ast, ImportRewrite importRewriter,
			JunitHolder junitHolder) {
		if (!(junitHolder.minv instanceof MethodInvocation)) {
			return;
		}

		MethodInvocation mi = junitHolder.getMethodInvocation();
		if (mi.arguments().isEmpty()) {
			return;
		}

		String methodName = mi.getName().getIdentifier();

		// Check if this is a parameter order correction case
		if (METHODS_WITH_EXPECTED_ACTUAL.contains(methodName)) {
			swapParametersIfNeeded(mi, rewriter, group);
			return;
		}

		// assertTrue/assertFalse optimization: re-validate with the same rules that find() used
		Expression condition = findCondition(mi);
		if (condition == null || !canOptimize(condition)) {
			return;
		}
		boolean isTrue = "assertTrue".equals(methodName);

		if (condition instanceof PrefixExpression prefix) {
			// assertTrue(!x) → assertFalse(x), assertFalse(!x) → assertTrue(x)
			renameAssertion(mi, isTrue ? "assertFalse" : "assertTrue", rewriter, ast, group);
			rewriter.getListRewrite(mi, MethodInvocation.ARGUMENTS_PROPERTY)
				.replace(condition, rewriter.createCopyTarget(prefix.getOperand()), group);
		} else if (condition instanceof InfixExpression infix) {
			rewriteComparison(mi, infix, isTrue, rewriter, ast, group);
		} else if (condition instanceof MethodInvocation equalsCall) {
			rewriteEqualsCall(mi, equalsCall, isTrue, rewriter, ast, group);
		}
	}

	/**
	 * Rewrites {@code assertTrue/assertFalse(a == b)} and {@code (a != b)}.
	 * A null comparison becomes {@code assertNull}/{@code assertNotNull}. A primitive comparison
	 * becomes {@code assertEquals}/{@code assertNotEquals}. A reference comparison becomes
	 * {@code assertSame}/{@code assertNotSame}. Assumes {@link #canOptimize} accepted the
	 * condition, so the operator is {@code ==} or {@code !=}.
	 */
	private static void rewriteComparison(MethodInvocation mi, InfixExpression infix, boolean isTrue,
			ASTRewrite rewriter, AST ast, TextEditGroup group) {
		// assertTrue(a == b) and assertFalse(a != b) assert equality; the other combinations assert inequality
		boolean assertsEquality = isTrue == (infix.getOperator() == InfixExpression.Operator.EQUALS);
		Expression left = infix.getLeftOperand();
		Expression right = infix.getRightOperand();
		ListRewrite argsRewrite = rewriter.getListRewrite(mi, MethodInvocation.ARGUMENTS_PROPERTY);

		if (left instanceof NullLiteral || right instanceof NullLiteral) {
			// Replace in place so that an optional message stays where it was
			Expression nonNullExpr = left instanceof NullLiteral ? right : left;
			renameAssertion(mi, assertsEquality ? "assertNull" : "assertNotNull", rewriter, ast, group);
			argsRewrite.replace(infix, rewriter.createCopyTarget(nonNullExpr), group);
			return;
		}

		String newMethodName;
		if (isPrimitiveComparison(left, right)) {
			newMethodName = assertsEquality ? "assertEquals" : "assertNotEquals";
		} else {
			// == on references is an identity check
			newMethodName = assertsEquality ? "assertSame" : "assertNotSame";
		}
		renameAssertion(mi, newMethodName, rewriter, ast, group);

		// Put a constant operand in the expected position: result == 5 → assertEquals(5, result)
		boolean swap = isReversedExpectedActual(left, right);
		replaceWithPair(argsRewrite, infix, swap ? right : left, swap ? left : right, rewriter, group);
	}

	/**
	 * Rewrites {@code assertTrue(a.equals(b))} to {@code assertEquals(b, a)} and
	 * {@code assertFalse(a.equals(b))} to {@code assertNotEquals(b, a)}. When only the receiver
	 * is constant ({@code "x".equals(actual)}), the receiver becomes the expected argument.
	 */
	private static void rewriteEqualsCall(MethodInvocation mi, MethodInvocation equalsCall, boolean isTrue,
			ASTRewrite rewriter, AST ast, TextEditGroup group) {
		Expression receiver = equalsCall.getExpression();
		Expression argument = (Expression) equalsCall.arguments().get(0);
		renameAssertion(mi, isTrue ? "assertEquals" : "assertNotEquals", rewriter, ast, group);

		boolean swap = isReversedExpectedActual(argument, receiver);
		ListRewrite argsRewrite = rewriter.getListRewrite(mi, MethodInvocation.ARGUMENTS_PROPERTY);
		replaceWithPair(argsRewrite, equalsCall, swap ? receiver : argument, swap ? argument : receiver,
				rewriter, group);
	}

	/** Replaces the assertion's method name. */
	private static void renameAssertion(MethodInvocation mi, String newMethodName, ASTRewrite rewriter, AST ast,
			TextEditGroup group) {
		rewriter.set(mi, MethodInvocation.NAME_PROPERTY, ast.newSimpleName(newMethodName), group);
	}

	/**
	 * Replaces the single {@code condition} argument with the two arguments {@code first, second}
	 * at the same position, so any message argument keeps its place.
	 */
	private static void replaceWithPair(ListRewrite argsRewrite, Expression condition, Expression first,
			Expression second, ASTRewrite rewriter, TextEditGroup group) {
		argsRewrite.replace(condition, rewriter.createCopyTarget(first), group);
		argsRewrite.insertAfter(rewriter.createCopyTarget(second), condition, group);
	}

	/**
	 * Swaps expected/actual parameters if they are in the wrong order.
	 * Expected (constant) should come first, actual (computed) should come second.
	 */
	private static void swapParametersIfNeeded(MethodInvocation mi, ASTRewrite rewriter, TextEditGroup group) {
		if (!needsParameterSwap(mi)) {
			return;
		}
		List<?> arguments = mi.arguments();
		int expectedIndex = expectedArgumentIndex(mi);
		Expression expectedParam = (Expression) arguments.get(expectedIndex);
		Expression actualParam = (Expression) arguments.get(expectedIndex + 1);

		ListRewrite argsRewrite = rewriter.getListRewrite(mi, MethodInvocation.ARGUMENTS_PROPERTY);
		argsRewrite.replace(expectedParam, rewriter.createCopyTarget(actualParam), group);
		argsRewrite.replace(actualParam, rewriter.createCopyTarget(expectedParam), group);
	}

	/** Returns {@code true} if at least one operand has a resolved primitive type. */
	private static boolean isPrimitiveComparison(Expression left, Expression right) {
		ITypeBinding leftType = left.resolveTypeBinding();
		ITypeBinding rightType = right.resolveTypeBinding();
		return (leftType != null && leftType.isPrimitive()) || (rightType != null && rightType.isPrimitive());
	}

	@Override
	public String getPreview(boolean afterRefactoring) {
		if (afterRefactoring) {
			return """
					Assertions.assertEquals(5, result);
					Assertions.assertNull(obj);
					Assertions.assertFalse(condition);
					"""; //$NON-NLS-1$
		}
		return """
				Assertions.assertTrue(result == 5);
				Assertions.assertTrue(obj == null);
				Assertions.assertTrue(!condition);
				"""; //$NON-NLS-1$
	}

	@Override
	public String toString() {
		return "AssertOptimization"; //$NON-NLS-1$
	}
}