RuleExpectedExceptionJUnitPlugin.java
/*******************************************************************************
* Copyright (c) 2025 Carsten Hammer.
*
* This program and the accompanying materials
* are made available under the terms of the Eclipse Public License 2.0
* which accompanies this distribution, and is available at
* https://www.eclipse.org/legal/epl-2.0/
*
* SPDX-License-Identifier: EPL-2.0
*
* Contributors:
* Carsten Hammer
*******************************************************************************/
package org.sandbox.jdt.internal.corext.fix.helper;
import static org.sandbox.jdt.internal.corext.fix.helper.lib.JUnitConstants.*;
/*-
* #%L
* Sandbox junit cleanup
* %%
* Copyright (C) 2024 hammer
* %%
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
* #L%
*/
import java.util.Collection;
import java.util.List;
import java.util.Set;
import org.eclipse.jdt.core.dom.AST;
import org.eclipse.jdt.core.dom.ASTNode;
import org.eclipse.jdt.core.dom.ASTVisitor;
import org.eclipse.jdt.core.dom.Block;
import org.eclipse.jdt.core.dom.CompilationUnit;
import org.eclipse.jdt.core.dom.Expression;
import org.eclipse.jdt.core.dom.ExpressionStatement;
import org.eclipse.jdt.core.dom.FieldDeclaration;
import org.eclipse.jdt.core.dom.IBinding;
import org.eclipse.jdt.core.dom.ITypeBinding;
import org.eclipse.jdt.core.dom.LambdaExpression;
import org.eclipse.jdt.core.dom.MethodDeclaration;
import org.eclipse.jdt.core.dom.MethodInvocation;
import org.eclipse.jdt.core.dom.SimpleName;
import org.eclipse.jdt.core.dom.Statement;
import org.eclipse.jdt.core.dom.Type;
import org.eclipse.jdt.core.dom.TypeDeclaration;
import org.eclipse.jdt.core.dom.TypeLiteral;
import org.eclipse.jdt.core.dom.VariableDeclarationFragment;
import org.eclipse.jdt.core.dom.VariableDeclarationStatement;
import org.eclipse.jdt.core.dom.rewrite.ASTRewrite;
import org.eclipse.jdt.core.dom.rewrite.ImportRewrite;
import org.eclipse.jdt.internal.corext.dom.ASTNodes;
import org.eclipse.jdt.internal.corext.fix.CompilationUnitRewriteOperationsFixCore.CompilationUnitRewriteOperationWithSourceRange;
import org.eclipse.text.edits.TextEditGroup;
import org.sandbox.jdt.internal.common.HelperVisitor;
import org.sandbox.jdt.internal.common.ReferenceHolder;
import org.sandbox.jdt.internal.corext.fix.JUnitCleanUpFixCore;
import org.sandbox.jdt.internal.corext.fix.helper.lib.AbstractTool;
import org.sandbox.jdt.internal.corext.fix.helper.lib.JunitHolder;
/**
* Plugin to migrate JUnit 4 ExpectedException rule to JUnit 5 assertThrows.
*/
/**
 * Plugin to migrate the JUnit 4 {@code ExpectedException} rule to JUnit 5 {@code assertThrows}.
 * <p>
 * For every method whose top-level statements call {@code expect(..)}, {@code expectMessage(..)} or
 * {@code expectCause(..)} on the rule field, all other statements following the first expectation are
 * wrapped into an {@code assertThrows} lambda. Message and cause expectations become assertions on the
 * captured exception.
 * <p>
 * Methods that cannot be translated faithfully are left untouched. This covers matcher-based
 * expectations, repeated expectations, other uses of the rule field, and an expectation with no code
 * after it. In those cases the rule field and its imports are kept, so the result still compiles.
 */
public class RuleExpectedExceptionJUnitPlugin extends AbstractTool<ReferenceHolder<Integer, JunitHolder>> {

	/** {@code ExpectedException} methods this plugin recognizes as top-level expectation statements. */
	private static final Set<String> EXPECTATION_METHODS = Set.of("expect", "expectMessage", "expectCause");

	@Override
	public void find(JUnitCleanUpFixCore fixcore, CompilationUnit compilationUnit,
			Set<CompilationUnitRewriteOperationWithSourceRange> operations, Set<ASTNode> nodesprocessed) {
		ReferenceHolder<Integer, JunitHolder> dataHolder = new ReferenceHolder<>();
		HelperVisitor.forField()
				.withAnnotation(ORG_JUNIT_RULE)
				.ofType(ORG_JUNIT_RULES_EXPECTED_EXCEPTION)
				.in(compilationUnit)
				.excluding(nodesprocessed)
				.processEach(dataHolder, (visited, aholder) -> processFoundNode(fixcore, operations, (FieldDeclaration) visited, aholder));
	}

	/**
	 * Registers a rewrite operation for a field whose declared type resolves to {@code ExpectedException}.
	 *
	 * @return always {@code true} so that the visitor continues with the remaining fields
	 */
	private boolean processFoundNode(JUnitCleanUpFixCore fixcore,
			Set<CompilationUnitRewriteOperationWithSourceRange> operations, FieldDeclaration node,
			ReferenceHolder<Integer, JunitHolder> dataHolder) {
		VariableDeclarationFragment fragment = (VariableDeclarationFragment) node.fragments().get(0);
		var variableBinding = fragment.resolveBinding();
		if (variableBinding == null) {
			// No bindings available: skip this field but keep visiting the others.
			return true;
		}
		ITypeBinding binding = variableBinding.getType();
		if (binding != null && ORG_JUNIT_RULES_EXPECTED_EXCEPTION.equals(binding.getQualifiedName())) {
			JunitHolder mh = new JunitHolder();
			mh.minv = node;
			dataHolder.put(dataHolder.size(), mh);
			operations.add(fixcore.rewrite(dataHolder));
		}
		return true;
	}

	@Override
	protected
	void process2Rewrite(TextEditGroup group, ASTRewrite rewriter, AST ast, ImportRewrite importRewriter,
			JunitHolder junitHolder) {
		FieldDeclaration field = junitHolder.getFieldDeclaration();
		TypeDeclaration parentClass = ASTNodes.getParent(field, TypeDeclaration.class);
		if (parentClass == null) {
			// Not declared in a regular class (e.g. enum or annotation body): nothing we can migrate.
			return;
		}
		VariableDeclarationFragment originalFragment = (VariableDeclarationFragment) field.fragments().get(0);
		String fieldName = originalFragment.getName().getIdentifier();
		IBinding fieldBinding = originalFragment.resolveBinding();

		boolean allUsagesMigrated = true;
		// Every method is visited, even after one fails, so that all migratable methods get rewritten.
		for (MethodDeclaration method : parentClass.getMethods()) {
			if (!transformTestMethod(method, fieldName, fieldBinding, rewriter, ast, group, importRewriter)) {
				allUsagesMigrated = false;
			}
		}
		// Other members (fields such as RuleChain.outerRule(thrown), initializers, nested types)
		// would break if the rule field disappeared.
		for (Object declaration : parentClass.bodyDeclarations()) {
			if (declaration != field && !(declaration instanceof MethodDeclaration)
					&& referencesField((ASTNode) declaration, fieldName, fieldBinding)) {
				allUsagesMigrated = false;
			}
		}
		if (!allUsagesMigrated) {
			// Keep the rule so that code which was not migrated still compiles.
			return;
		}
		rewriter.remove(field, group);
		importRewriter.removeImport(ORG_JUNIT_RULE);
		importRewriter.removeImport(ORG_JUNIT_RULES_EXPECTED_EXCEPTION);
	}

	/**
	 * Rewrites one method from {@code ExpectedException} expectations to {@code assertThrows}.
	 *
	 * @return {@code true} if, after this call, the method no longer references the rule field (it never
	 *         did, or it was migrated); {@code false} if it still references the field and was left as is
	 */
	private boolean transformTestMethod(MethodDeclaration method, String fieldName, IBinding fieldBinding,
			ASTRewrite rewriter, AST ast, TextEditGroup group, ImportRewrite importRewriter) {
		Block methodBody = method.getBody();
		if (methodBody == null) {
			return true;
		}
		List<Statement> statements = methodBody.statements();
		ExpectedExceptionInfo info = findExpectedExceptionCalls(statements, fieldName);
		if (info.expectCall == null) {
			// No top-level expect(). This is fine unless the field is used some other way,
			// e.g. expectMessage() alone or an expectation nested in a block.
			return !referencesField(methodBody, fieldName, fieldBinding);
		}
		if (!isMigratable(info, statements, fieldName, fieldBinding)) {
			return false;
		}

		Expression exceptionClass = (Expression) info.expectCall.arguments().get(0);
		MethodInvocation assertThrowsCall = ast.newMethodInvocation();
		assertThrowsCall.setName(ast.newSimpleName("assertThrows"));
		assertThrowsCall.arguments().add(ASTNode.copySubtree(ast, exceptionClass));
		assertThrowsCall.arguments().add(createLambda(ast, statements, info.firstExpectStatementIndex, fieldName));
		importRewriter.addStaticImport(ORG_JUNIT_JUPITER_API_ASSERTIONS, "assertThrows", false);

		// The thrown exception is only captured in a variable when a follow-up assertion needs it.
		String exceptionVarName = null;
		Statement assertThrowsStatement;
		if (info.expectMessageCall != null || info.expectCauseCall != null) {
			Collection<String> usedNames = getUsedVariableNames(method);
			exceptionVarName = generateUniqueVariableName("exception", usedNames);
			VariableDeclarationFragment fragment = ast.newVariableDeclarationFragment();
			fragment.setName(ast.newSimpleName(exceptionVarName));
			fragment.setInitializer(assertThrowsCall);
			VariableDeclarationStatement varDecl = ast.newVariableDeclarationStatement(fragment);
			varDecl.setType(createExceptionVariableType(ast, exceptionClass));
			assertThrowsStatement = varDecl;
		} else {
			assertThrowsStatement = ast.newExpressionStatement(assertThrowsCall);
		}

		// All statements from the first expectation on are either expectations (replaced) or were
		// copied into the lambda, so they can all be removed.
		for (int i = statements.size() - 1; i >= info.firstExpectStatementIndex; i--) {
			rewriter.remove(statements.get(i), group);
		}
		var bodyRewrite = rewriter.getListRewrite(methodBody, Block.STATEMENTS_PROPERTY);
		bodyRewrite.insertLast(assertThrowsStatement, group);

		if (info.expectMessageCall != null) {
			// NOTE(review): JUnit 4 expectMessage(String) checks that the message *contains* the text;
			// assertEquals is stricter. Kept to preserve the plugin's established output.
			Expression messageArg = (Expression) info.expectMessageCall.arguments().get(0);
			MethodInvocation assertEqualsCall = ast.newMethodInvocation();
			assertEqualsCall.setName(ast.newSimpleName("assertEquals"));
			assertEqualsCall.arguments().add(ASTNode.copySubtree(ast, messageArg));
			assertEqualsCall.arguments().add(newNoArgCall(ast, exceptionVarName, "getMessage"));
			bodyRewrite.insertLast(ast.newExpressionStatement(assertEqualsCall), group);
			importRewriter.addStaticImport(ORG_JUNIT_JUPITER_API_ASSERTIONS, "assertEquals", false);
		}
		if (info.expectCauseCall != null) {
			// isMigratable() guarantees a supported instanceOf()/isA() matcher here.
			Expression causeClass = extractCauseClass((Expression) info.expectCauseCall.arguments().get(0));
			MethodInvocation assertInstanceOfCall = ast.newMethodInvocation();
			assertInstanceOfCall.setName(ast.newSimpleName("assertInstanceOf"));
			assertInstanceOfCall.arguments().add(ASTNode.copySubtree(ast, causeClass));
			assertInstanceOfCall.arguments().add(newNoArgCall(ast, exceptionVarName, "getCause"));
			bodyRewrite.insertLast(ast.newExpressionStatement(assertInstanceOfCall), group);
			importRewriter.addStaticImport(ORG_JUNIT_JUPITER_API_ASSERTIONS, "assertInstanceOf", false);
		}
		return true;
	}

	/**
	 * Checks that a method's expectations can be expressed with assertThrows without losing or
	 * corrupting any check, and that no statement other than the recognized expectations uses the field.
	 */
	private static boolean isMigratable(ExpectedExceptionInfo info, List<Statement> statements, String fieldName,
			IBinding fieldBinding) {
		if (info.hasDuplicateExpectation) {
			// Repeated expectations are combined by JUnit 4; we cannot express that faithfully.
			return false;
		}
		if (info.expectCall.arguments().size() != 1
				|| !isClassExpression((Expression) info.expectCall.arguments().get(0))) {
			// expect(Matcher) has no assertThrows equivalent.
			return false;
		}
		if (info.expectMessageCall != null && (info.expectMessageCall.arguments().size() != 1
				|| !isStringExpression((Expression) info.expectMessageCall.arguments().get(0)))) {
			// expectMessage(Matcher) would otherwise become assertEquals(matcher, message).
			return false;
		}
		if (info.expectCauseCall != null && (info.expectCauseCall.arguments().size() != 1
				|| extractCauseClass((Expression) info.expectCauseCall.arguments().get(0)) == null)) {
			// Unsupported cause matcher: migrating would silently drop the cause check.
			return false;
		}
		boolean hasCodeUnderTest = false;
		for (int i = 0; i < statements.size(); i++) {
			Statement statement = statements.get(i);
			if (asExpectationCall(statement, fieldName) != null) {
				continue;
			}
			if (referencesField(statement, fieldName, fieldBinding)) {
				return false;
			}
			if (i > info.firstExpectStatementIndex) {
				hasCodeUnderTest = true;
			}
		}
		// Without code after the expectation the lambda would be empty and never throw.
		return hasCodeUnderTest;
	}

	/**
	 * Builds {@code () -> { ... }} from all non-expectation statements after the first expectation.
	 * Statements between expectations are included, because JUnit 4 expected them to possibly throw as well.
	 */
	private static LambdaExpression createLambda(AST ast, List<Statement> statements, int firstExpectIndex,
			String fieldName) {
		Block lambdaBody = ast.newBlock();
		for (int i = firstExpectIndex + 1; i < statements.size(); i++) {
			Statement statement = statements.get(i);
			if (asExpectationCall(statement, fieldName) == null) {
				lambdaBody.statements().add(ASTNode.copySubtree(ast, statement));
			}
		}
		LambdaExpression lambda = ast.newLambdaExpression();
		lambda.setParentheses(true);
		lambda.setBody(lambdaBody);
		return lambda;
	}

	/** Creates {@code receiver.methodName()}. */
	private static MethodInvocation newNoArgCall(AST ast, String receiver, String methodName) {
		MethodInvocation call = ast.newMethodInvocation();
		call.setExpression(ast.newSimpleName(receiver));
		call.setName(ast.newSimpleName(methodName));
		return call;
	}

	/**
	 * Returns the declared type for the captured exception variable. This is the type of a
	 * {@code X.class} literal, or {@code Throwable} when the class is only known through an expression.
	 */
	private static Type createExceptionVariableType(AST ast, Expression exceptionClass) {
		if (exceptionClass instanceof TypeLiteral typeLiteral) {
			// Copy the Type as written so that a simple name stays simple.
			return (Type) ASTNode.copySubtree(ast, typeLiteral.getType());
		}
		return ast.newSimpleType(ast.newSimpleName("Throwable"));
	}

	/** Returns {@code true} for a class literal or any expression whose static type is {@code java.lang.Class}. */
	private static boolean isClassExpression(Expression expression) {
		if (expression instanceof TypeLiteral) {
			return true;
		}
		ITypeBinding binding = expression.resolveTypeBinding();
		return binding != null && "java.lang.Class".equals(binding.getErasure().getQualifiedName());
	}

	/** Returns {@code true} if the expression is a {@code String} (by binding, or literal when unresolved). */
	private static boolean isStringExpression(Expression expression) {
		ITypeBinding binding = expression.resolveTypeBinding();
		if (binding != null) {
			return "java.lang.String".equals(binding.getQualifiedName());
		}
		return expression.getNodeType() == ASTNode.STRING_LITERAL;
	}

	/**
	 * Returns {@code true} if any simple name under {@code node} refers to the rule field.
	 * When bindings are unavailable, the name alone is treated as a reference (conservative).
	 */
	private static boolean referencesField(ASTNode node, String fieldName, IBinding fieldBinding) {
		boolean[] found = { false };
		node.accept(new ASTVisitor() {
			@Override
			public boolean preVisit2(ASTNode child) {
				// Stop descending once a reference has been found.
				return !found[0];
			}

			@Override
			public boolean visit(SimpleName name) {
				if (fieldName.equals(name.getIdentifier())) {
					IBinding binding = name.resolveBinding();
					if (fieldBinding == null || binding == null || fieldBinding.isEqualTo(binding)) {
						found[0] = true;
					}
				}
				return false;
			}
		});
		return found[0];
	}

	/**
	 * Returns the invocation if {@code statement} is {@code fieldName.expect|expectMessage|expectCause(..);},
	 * otherwise {@code null}.
	 */
	private static MethodInvocation asExpectationCall(Statement statement, String fieldName) {
		if (statement instanceof ExpressionStatement expressionStatement
				&& expressionStatement.getExpression() instanceof MethodInvocation invocation
				&& invocation.getExpression() instanceof SimpleName receiver
				&& fieldName.equals(receiver.getIdentifier())
				&& EXPECTATION_METHODS.contains(invocation.getName().getIdentifier())) {
			return invocation;
		}
		return null;
	}

	/** Collects the top-level expectation calls on {@code fieldName} and the index of the first one. */
	private ExpectedExceptionInfo findExpectedExceptionCalls(List<Statement> statements, String fieldName) {
		ExpectedExceptionInfo info = new ExpectedExceptionInfo();
		for (int i = 0; i < statements.size(); i++) {
			MethodInvocation invocation = asExpectationCall(statements.get(i), fieldName);
			if (invocation == null) {
				continue;
			}
			if (info.firstExpectStatementIndex == -1) {
				info.firstExpectStatementIndex = i;
			}
			switch (invocation.getName().getIdentifier()) {
			case "expect" -> {
				info.hasDuplicateExpectation |= info.expectCall != null;
				info.expectCall = invocation;
			}
			case "expectMessage" -> {
				info.hasDuplicateExpectation |= info.expectMessageCall != null;
				info.expectMessageCall = invocation;
			}
			case "expectCause" -> {
				info.hasDuplicateExpectation |= info.expectCauseCall != null;
				info.expectCauseCall = invocation;
			}
			default -> {
				// unreachable: asExpectationCall only accepts EXPECTATION_METHODS
			}
			}
		}
		return info;
	}

	/**
	 * Extracts the cause exception class from a Hamcrest matcher expression.
	 *
	 * Supported Hamcrest matchers:
	 * - org.hamcrest.Matchers.instanceOf(ExceptionClass.class)
	 * - org.hamcrest.Matchers.isA(ExceptionClass.class)
	 *
	 * Unsupported matchers (will return null):
	 * - any(Class.class)
	 * - notNullValue()
	 * - Custom matchers
	 *
	 * @param causeArg the expression passed to expectCause()
	 * @return the class literal expression, or null if the matcher is not supported
	 */
	private static Expression extractCauseClass(Expression causeArg) {
		if (causeArg instanceof MethodInvocation methodInv) {
			String methodName = methodInv.getName().getIdentifier();
			if (("instanceOf".equals(methodName) || "isA".equals(methodName)) && !methodInv.arguments().isEmpty()) {
				return (Expression) methodInv.arguments().get(0);
			}
		}
		return null;
	}

	/** Returns {@code baseName}, or {@code baseName1}, {@code baseName2}, ... until it is not in {@code usedNames}. */
	private String generateUniqueVariableName(String baseName, Collection<String> usedNames) {
		if (!usedNames.contains(baseName)) {
			return baseName;
		}
		int counter = 1;
		String candidateName;
		do {
			candidateName = baseName + counter;
			counter++;
		} while (usedNames.contains(candidateName));
		return candidateName;
	}

	@Override
	public String getPreview(boolean afterRefactoring) {
		if (afterRefactoring) {
			return """
					import static org.junit.jupiter.api.Assertions.assertEquals;
					import static org.junit.jupiter.api.Assertions.assertThrows;
					import org.junit.jupiter.api.Test;
					public class MyTest {
					@Test
					public void testException() {
					IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> {
					throw new IllegalArgumentException("Invalid argument");
					});
					assertEquals("Invalid argument", exception.getMessage());
					}
					}
					"""; //$NON-NLS-1$
		}
		return """
				import org.junit.Rule;
				import org.junit.Test;
				import org.junit.rules.ExpectedException;
				public class MyTest {
				@Rule
				public ExpectedException thrown = ExpectedException.none();
				@Test
				public void testException() {
				thrown.expect(IllegalArgumentException.class);
				thrown.expectMessage("Invalid argument");
				throw new IllegalArgumentException("Invalid argument");
				}
				}
				"""; //$NON-NLS-1$
	}

	@Override
	public String toString() {
		return "RuleExpectedException"; //$NON-NLS-1$
	}

	/** Expectation calls found among a method's top-level statements. */
	private static class ExpectedExceptionInfo {
		MethodInvocation expectCall;
		MethodInvocation expectMessageCall;
		MethodInvocation expectCauseCall;
		/** Index of the first expectation statement, or -1 if there is none. */
		int firstExpectStatementIndex = -1;
		/** Set when the same kind of expectation occurs more than once. */
		boolean hasDuplicateExpectation;
	}
}