RuleExpectedExceptionJUnitPlugin.java
/*******************************************************************************
* Copyright (c) 2025 Carsten Hammer.
*
* This program and the accompanying materials
* are made available under the terms of the Eclipse Public License 2.0
* which accompanies this distribution, and is available at
* https://www.eclipse.org/legal/epl-2.0/
*
* SPDX-License-Identifier: EPL-2.0
*
* Contributors:
* Carsten Hammer
*******************************************************************************/
package org.sandbox.jdt.internal.corext.fix.helper;
import static org.sandbox.jdt.internal.corext.fix.helper.lib.JUnitConstants.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.eclipse.jdt.core.dom.AST;
import org.eclipse.jdt.core.dom.ASTNode;
import org.eclipse.jdt.core.dom.ASTVisitor;
import org.eclipse.jdt.core.dom.Block;
import org.eclipse.jdt.core.dom.Expression;
import org.eclipse.jdt.core.dom.ExpressionStatement;
import org.eclipse.jdt.core.dom.FieldDeclaration;
import org.eclipse.jdt.core.dom.ITypeBinding;
import org.eclipse.jdt.core.dom.IVariableBinding;
import org.eclipse.jdt.core.dom.LambdaExpression;
import org.eclipse.jdt.core.dom.MethodDeclaration;
import org.eclipse.jdt.core.dom.MethodInvocation;
import org.eclipse.jdt.core.dom.SimpleName;
import org.eclipse.jdt.core.dom.Statement;
import org.eclipse.jdt.core.dom.Type;
import org.eclipse.jdt.core.dom.TypeDeclaration;
import org.eclipse.jdt.core.dom.TypeLiteral;
import org.eclipse.jdt.core.dom.VariableDeclarationFragment;
import org.eclipse.jdt.core.dom.VariableDeclarationStatement;
import org.eclipse.jdt.core.dom.rewrite.ASTRewrite;
import org.eclipse.jdt.core.dom.rewrite.ImportRewrite;
import org.eclipse.jdt.core.dom.rewrite.ListRewrite;
import org.eclipse.jdt.internal.corext.dom.ASTNodes;
import org.eclipse.text.edits.TextEditGroup;
import org.sandbox.jdt.internal.corext.fix.helper.lib.JunitHolder;
import org.sandbox.jdt.internal.corext.fix.helper.lib.TriggerPatternCleanupPlugin;
import org.sandbox.jdt.triggerpattern.api.CleanupPattern;
import org.sandbox.jdt.triggerpattern.api.Match;
import org.sandbox.jdt.triggerpattern.api.PatternKind;
/**
 * Plugin to migrate JUnit 4 ExpectedException rule to JUnit 5 assertThrows.
 * <p>
 * Every method of the declaring class whose top-level statements call
 * {@code expect(Class)}, {@code expectMessage(String)} and/or
 * {@code expectCause(..)} on the rule field is rewritten: the expectation calls
 * are removed and all statements following the first expectation are wrapped in
 * an {@code assertThrows(..)} lambda. Message and cause expectations become
 * assertions on the captured exception.
 * <p>
 * The migration is all-or-nothing for a given field. If any reference to the
 * field cannot be migrated (a call nested inside another statement, the
 * {@code expect(Matcher)} overload, a use outside the recognized calls, ...), the
 * source is left unchanged. Deleting the field would otherwise leave
 * references that no longer compile.
 *
 * @since 1.3.0
 */
@CleanupPattern(value = "@Rule public ExpectedException $name", kind = PatternKind.FIELD, qualifiedType = ORG_JUNIT_RULES_EXPECTED_EXCEPTION, cleanupId = "cleanup.junit.ruleexpectedexception", description = "Migrate @Rule ExpectedException to assertThrows()", displayName = "JUnit 4 @Rule ExpectedException \u2192 JUnit 5 assertThrows()")
public class RuleExpectedExceptionJUnitPlugin extends TriggerPatternCleanupPlugin {

	/**
	 * Accepts the matched field only when its first fragment's type resolves to
	 * {@code org.junit.rules.ExpectedException}.
	 *
	 * @param match the trigger-pattern match; its matched node is a field declaration
	 * @return a holder wrapping the field declaration, or {@code null} to reject the match
	 */
	@Override
	protected JunitHolder createHolder(Match match) {
		FieldDeclaration fieldDecl = (FieldDeclaration) match.getMatchedNode();
		VariableDeclarationFragment fragment = (VariableDeclarationFragment) fieldDecl.fragments().get(0);
		IVariableBinding variableBinding = fragment.resolveBinding();
		if (variableBinding == null) {
			return null;
		}
		ITypeBinding binding = variableBinding.getType();
		if (binding == null || !ORG_JUNIT_RULES_EXPECTED_EXCEPTION.equals(binding.getQualifiedName())) {
			return null;
		}
		JunitHolder holder = new JunitHolder();
		holder.setMinv(fieldDecl);
		return holder;
	}

	/**
	 * Removes the rule field and rewrites every method that uses it.
	 * <p>
	 * First, all methods are analyzed without modifying anything. The rewrite is
	 * performed only if every reference to the field lies in an expectation call
	 * that will be migrated. Otherwise nothing is changed.
	 */
	@Override
	protected void process2Rewrite(TextEditGroup group, ASTRewrite rewriter, AST ast, ImportRewrite importRewriter,
			JunitHolder junitHolder) {
		FieldDeclaration field = junitHolder.getFieldDeclaration();
		TypeDeclaration parentClass = ASTNodes.getParent(field, TypeDeclaration.class);
		// Removing the declaration would also delete sibling fragments
		// ("ExpectedException a = ..., b = ...;"), so only single-fragment fields are handled.
		if (parentClass == null || field.fragments().size() != 1) {
			return;
		}
		VariableDeclarationFragment originalFragment = (VariableDeclarationFragment) field.fragments().get(0);
		IVariableBinding fieldBinding = originalFragment.resolveBinding();
		if (fieldBinding == null) {
			return;
		}

		// Pass 1: plan every method without touching the AST.
		List<MethodMigration> migrations = new ArrayList<>();
		int migratedReferences = 0;
		for (MethodDeclaration method : parentClass.getMethods()) {
			ExpectedExceptionInfo info = analyzeMethod(method, fieldBinding);
			if (info != null) {
				migrations.add(new MethodMigration(method, info));
				migratedReferences += info.getExpectStatementCount();
			}
		}
		// Every reference to the field must be one of the planned expectation calls;
		// any other use would dangle once the field is deleted.
		if (migratedReferences != countFieldReferences(field.getRoot(), fieldBinding)) {
			return;
		}

		// Pass 2: apply the rewrite.
		rewriter.remove(field, group);
		// NOTE(review): org.junit.Rule is removed unconditionally; presumably other @Rule
		// fields are migrated by sibling plugins - verify.
		importRewriter.removeImport(ORG_JUNIT_RULE);
		importRewriter.removeImport(ORG_JUNIT_RULES_EXPECTED_EXCEPTION);
		if (!migrations.isEmpty()) {
			importRewriter.addStaticImport(ORG_JUNIT_JUPITER_API_ASSERTIONS, "assertThrows", false);
		}
		for (MethodMigration migration : migrations) {
			transformTestMethod(migration.method(), migration.info(), rewriter, ast, group, importRewriter);
		}
	}

	/**
	 * Determines whether and how a single method can be migrated, without modifying anything.
	 *
	 * @param method       the method to inspect
	 * @param fieldBinding binding of the ExpectedException field
	 * @return the collected expectations, or {@code null} if the method has no
	 *         {@code expect(..)} call on the field or uses it in a way that cannot be migrated
	 */
	private ExpectedExceptionInfo analyzeMethod(MethodDeclaration method, IVariableBinding fieldBinding) {
		Block methodBody = method.getBody();
		if (methodBody == null) {
			return null;
		}
		List<Statement> statements = methodBody.statements();
		ExpectedExceptionInfo info = findExpectedExceptionCalls(statements, fieldBinding);
		if (info.getExpectCall() == null) {
			// No expect(): the rule is unused here, or only expectMessage()/expectCause() is
			// called. The latter is not migrated; those references stay uncovered, which
			// makes process2Rewrite leave the source untouched.
			return null;
		}
		if (!hasClassArgument(info.getExpectCall())) {
			// expect(Matcher) has no assertThrows equivalent.
			return null;
		}
		if (info.getExpectMessageCall() != null && !hasStringArgument(info.getExpectMessageCall())) {
			// expectMessage(Matcher) would otherwise become an always-failing assertEquals.
			return null;
		}
		if (collectLambdaStatements(statements, info).isEmpty()) {
			// Nothing left that could throw: an empty lambda would always fail assertThrows.
			return null;
		}
		return info;
	}

	/**
	 * Rewrites one method that {@link #analyzeMethod} accepted. Everything from the first
	 * expectation call onwards is replaced by an {@code assertThrows(..)} statement. That
	 * statement is followed by optional message and cause assertions.
	 */
	private void transformTestMethod(MethodDeclaration method, ExpectedExceptionInfo info, ASTRewrite rewriter,
			AST ast, TextEditGroup group, ImportRewrite importRewriter) {
		Block methodBody = method.getBody();
		List<Statement> statements = methodBody.statements();
		ListRewrite bodyRewrite = rewriter.getListRewrite(methodBody, Block.STATEMENTS_PROPERTY);

		// The exception only needs to be captured when its message or cause is checked.
		String exceptionVarName = null;
		if (info.getExpectMessageCall() != null || info.getExpectCauseCall() != null) {
			Collection<String> usedNames = getUsedVariableNames(method);
			exceptionVarName = generateUniqueVariableName("exception", usedNames);
		}

		// assertThrows(<expected class>, () -> { <non-expect statements after the first expectation> })
		MethodInvocation assertThrowsCall = ast.newMethodInvocation();
		assertThrowsCall.setName(ast.newSimpleName("assertThrows"));
		Expression exceptionClass = (Expression) ASTNode.copySubtree(ast,
				(Expression) info.getExpectCall().arguments().get(0));
		assertThrowsCall.arguments().add(exceptionClass);
		LambdaExpression lambda = ast.newLambdaExpression();
		lambda.setParentheses(true);
		Block lambdaBody = ast.newBlock();
		for (Statement stmt : collectLambdaStatements(statements, info)) {
			lambdaBody.statements().add(ASTNode.copySubtree(ast, stmt));
		}
		lambda.setBody(lambdaBody);
		assertThrowsCall.arguments().add(lambda);

		Statement newStatement;
		if (exceptionVarName != null) {
			// ExceptionType exception = assertThrows(ExceptionType.class, () -> { ... });
			VariableDeclarationFragment fragment = ast.newVariableDeclarationFragment();
			fragment.setName(ast.newSimpleName(exceptionVarName));
			fragment.setInitializer(assertThrowsCall);
			VariableDeclarationStatement varDecl = ast.newVariableDeclarationStatement(fragment);
			Type exceptionType = extractExceptionType(info.getExpectCall());
			if (exceptionType != null) {
				// Copy the literal's Type so the simple name used in source is preserved.
				varDecl.setType((Type) ASTNode.copySubtree(ast, exceptionType));
			} else {
				// expect(someClassVariable): no type is spelled out in source. assertThrows
				// returns some T extends Throwable, so Throwable is always assignable.
				varDecl.setType(ast.newSimpleType(ast.newSimpleName("Throwable"))); //$NON-NLS-1$
			}
			newStatement = varDecl;
		} else {
			newStatement = ast.newExpressionStatement(assertThrowsCall);
		}

		// Remove the expectation calls and everything after them (now inside the lambda).
		for (int i = statements.size() - 1; i >= info.getFirstExpectStatementIndex(); i--) {
			rewriter.remove(statements.get(i), group);
		}
		bodyRewrite.insertLast(newStatement, group);

		if (info.getExpectMessageCall() != null) {
			// assertEquals("message", exception.getMessage());
			// NOTE(review): JUnit 4's expectMessage(String) checks that the message *contains*
			// the text. assertEquals is stricter and fails for partial messages. Consider
			// emitting assertTrue(exception.getMessage().contains(..)) instead.
			Expression messageArg = (Expression) info.getExpectMessageCall().arguments().get(0);
			MethodInvocation getMessageCall = ast.newMethodInvocation();
			getMessageCall.setExpression(ast.newSimpleName(exceptionVarName));
			getMessageCall.setName(ast.newSimpleName("getMessage"));
			MethodInvocation assertEqualsCall = ast.newMethodInvocation();
			assertEqualsCall.setName(ast.newSimpleName("assertEquals"));
			assertEqualsCall.arguments().add(ASTNode.copySubtree(ast, messageArg));
			assertEqualsCall.arguments().add(getMessageCall);
			bodyRewrite.insertLast(ast.newExpressionStatement(assertEqualsCall), group);
			importRewriter.addStaticImport(ORG_JUNIT_JUPITER_API_ASSERTIONS, "assertEquals", false);
		}

		if (info.getExpectCauseCall() != null && !info.getExpectCauseCall().arguments().isEmpty()) {
			Expression causeArg = (Expression) info.getExpectCauseCall().arguments().get(0);
			Expression causeClass = extractCauseClass(causeArg);
			if (causeClass != null) {
				// assertInstanceOf(CauseClass.class, exception.getCause());
				MethodInvocation getCauseCall = ast.newMethodInvocation();
				getCauseCall.setExpression(ast.newSimpleName(exceptionVarName));
				getCauseCall.setName(ast.newSimpleName("getCause"));
				MethodInvocation assertInstanceOfCall = ast.newMethodInvocation();
				assertInstanceOfCall.setName(ast.newSimpleName("assertInstanceOf"));
				assertInstanceOfCall.arguments().add(ASTNode.copySubtree(ast, causeClass));
				assertInstanceOfCall.arguments().add(getCauseCall);
				bodyRewrite.insertLast(ast.newExpressionStatement(assertInstanceOfCall), group);
				importRewriter.addStaticImport(ORG_JUNIT_JUPITER_API_ASSERTIONS, "assertInstanceOf", false);
			}
			// Unsupported matchers are silently skipped - manual migration required
		}
	}

	/**
	 * Returns the statements to move into the assertThrows lambda. These are all
	 * statements after the first expectation call, except the expectation calls
	 * themselves. Statements interleaved between expectation calls are kept
	 * rather than dropped.
	 */
	private static List<Statement> collectLambdaStatements(List<Statement> statements, ExpectedExceptionInfo info) {
		List<Statement> lambdaStatements = new ArrayList<>();
		for (int i = info.getFirstExpectStatementIndex() + 1; i < statements.size(); i++) {
			if (!info.isExpectStatement(i)) {
				lambdaStatements.add(statements.get(i));
			}
		}
		return lambdaStatements;
	}

	/**
	 * Scans the top-level statements for {@code expect}, {@code expectMessage} and
	 * {@code expectCause} calls whose receiver is the rule field. The receiver is
	 * identified by binding, so a local variable shadowing the field name is ignored.
	 * When a call kind occurs more than once, the last occurrence wins.
	 */
	private static ExpectedExceptionInfo findExpectedExceptionCalls(List<Statement> statements,
			IVariableBinding fieldBinding) {
		ExpectedExceptionInfo info = new ExpectedExceptionInfo();
		for (int i = 0; i < statements.size(); i++) {
			Statement stmt = statements.get(i);
			if (!(stmt instanceof ExpressionStatement exprStmt)) {
				continue;
			}
			if (!(exprStmt.getExpression() instanceof MethodInvocation invocation)) {
				continue;
			}
			if (!(invocation.getExpression() instanceof SimpleName receiver)
					|| !fieldBinding.isEqualTo(receiver.resolveBinding())) {
				continue;
			}
			String methodName = invocation.getName().getIdentifier();
			if ("expect".equals(methodName)) {
				info.setExpectCall(invocation);
				info.addExpectStatement(i);
			} else if ("expectMessage".equals(methodName)) {
				info.setExpectMessageCall(invocation);
				info.addExpectStatement(i);
			} else if ("expectCause".equals(methodName)) {
				info.setExpectCauseCall(invocation);
				info.addExpectStatement(i);
			}
		}
		return info;
	}

	/**
	 * Tells whether {@code expect(..)} was called with a {@code Class} argument (the form
	 * {@code assertThrows} accepts) rather than with a Hamcrest matcher.
	 */
	private static boolean hasClassArgument(MethodInvocation expectCall) {
		if (expectCall.arguments().size() != 1) {
			return false;
		}
		Expression arg = (Expression) expectCall.arguments().get(0);
		if (arg instanceof TypeLiteral) {
			return true;
		}
		ITypeBinding argType = arg.resolveTypeBinding();
		return argType != null && "java.lang.Class".equals(argType.getErasure().getQualifiedName()); //$NON-NLS-1$
	}

	/**
	 * Tells whether {@code expectMessage(..)} was called with a {@code String} argument
	 * rather than with a Hamcrest matcher.
	 */
	private static boolean hasStringArgument(MethodInvocation expectMessageCall) {
		if (expectMessageCall.arguments().size() != 1) {
			return false;
		}
		ITypeBinding argType = ((Expression) expectMessageCall.arguments().get(0)).resolveTypeBinding();
		return argType != null && "java.lang.String".equals(argType.getQualifiedName()); //$NON-NLS-1$
	}

	/**
	 * Counts the non-declaring simple-name references to the field anywhere below {@code scope}.
	 */
	private static int countFieldReferences(ASTNode scope, IVariableBinding fieldBinding) {
		int[] count = { 0 };
		scope.accept(new ASTVisitor() {
			@Override
			public boolean visit(SimpleName node) {
				if (!node.isDeclaration() && fieldBinding.isEqualTo(node.resolveBinding())) {
					count[0]++;
				}
				return false;
			}
		});
		return count[0];
	}

	/**
	 * Returns the {@link Type} of a class-literal argument to {@code expect(..)}
	 * (e.g. {@code IllegalArgumentException.class}), or {@code null} if the argument
	 * is not a type literal.
	 */
	private Type extractExceptionType(MethodInvocation expectCall) {
		if (!expectCall.arguments().isEmpty()) {
			Expression arg = (Expression) expectCall.arguments().get(0);
			if (arg instanceof TypeLiteral typeLiteral) {
				return typeLiteral.getType();
			}
		}
		return null;
	}

	/**
	 * Extracts the cause exception class from a Hamcrest matcher expression.
	 *
	 * Supported Hamcrest matchers: -
	 * org.hamcrest.Matchers.instanceOf(ExceptionClass.class) -
	 * org.hamcrest.Matchers.isA(ExceptionClass.class)
	 *
	 * Unsupported matchers (will return null): - any(Class.class) - notNullValue()
	 * - Custom matchers
	 *
	 * @param causeArg the expression passed to expectCause()
	 * @return the class literal expression, or null if the matcher is not supported
	 */
	private Expression extractCauseClass(Expression causeArg) {
		if (causeArg instanceof MethodInvocation methodInv) {
			String methodName = methodInv.getName().getIdentifier();
			if (("instanceOf".equals(methodName) || "isA".equals(methodName)) && !methodInv.arguments().isEmpty()) {
				return (Expression) methodInv.arguments().get(0);
			}
		}
		return null;
	}

	/**
	 * Returns {@code baseName} if unused, otherwise the first of {@code baseName1},
	 * {@code baseName2}, ... that is not in {@code usedNames}.
	 */
	private String generateUniqueVariableName(String baseName, Collection<String> usedNames) {
		if (!usedNames.contains(baseName)) {
			return baseName;
		}
		int counter = 1;
		String candidateName;
		do {
			candidateName = baseName + counter;
			counter++;
		} while (usedNames.contains(candidateName));
		return candidateName;
	}

	@Override
	public String getPreview(boolean afterRefactoring) {
		if (afterRefactoring) {
			return """
					import static org.junit.jupiter.api.Assertions.assertEquals;
					import static org.junit.jupiter.api.Assertions.assertThrows;
					import org.junit.jupiter.api.Test;
					public class MyTest {
					@Test
					public void testException() {
					IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> {
					throw new IllegalArgumentException("Invalid argument");
					});
					assertEquals("Invalid argument", exception.getMessage());
					}
					}
					"""; //$NON-NLS-1$
		}
		return """
				import org.junit.Rule;
				import org.junit.Test;
				import org.junit.rules.ExpectedException;
				public class MyTest {
				@Rule
				public ExpectedException thrown = ExpectedException.none();
				@Test
				public void testException() {
				thrown.expect(IllegalArgumentException.class);
				thrown.expectMessage("Invalid argument");
				throw new IllegalArgumentException("Invalid argument");
				}
				}
				"""; //$NON-NLS-1$
	}

	@Override
	public String toString() {
		return "RuleExpectedException"; //$NON-NLS-1$
	}

	/** A method accepted in the planning pass, together with its collected expectations. */
	private record MethodMigration(MethodDeclaration method, ExpectedExceptionInfo info) {
	}

	/**
	 * Expectation calls found in one method body, plus the indices of all top-level
	 * statements that are expectation calls (recorded in ascending order).
	 */
	private static class ExpectedExceptionInfo {
		private MethodInvocation expectCall;
		private MethodInvocation expectMessageCall;
		private MethodInvocation expectCauseCall;
		private final List<Integer> expectStatementIndices = new ArrayList<>();

		MethodInvocation getExpectCall() {
			return expectCall;
		}

		void setExpectCall(MethodInvocation expectCall) {
			this.expectCall = expectCall;
		}

		MethodInvocation getExpectMessageCall() {
			return expectMessageCall;
		}

		void setExpectMessageCall(MethodInvocation expectMessageCall) {
			this.expectMessageCall = expectMessageCall;
		}

		MethodInvocation getExpectCauseCall() {
			return expectCauseCall;
		}

		void setExpectCauseCall(MethodInvocation expectCauseCall) {
			this.expectCauseCall = expectCauseCall;
		}

		/** Index of the first expectation statement, or -1 if none was recorded. */
		int getFirstExpectStatementIndex() {
			return expectStatementIndices.isEmpty() ? -1 : expectStatementIndices.get(0);
		}

		boolean isExpectStatement(int index) {
			return expectStatementIndices.contains(index);
		}

		/** Number of recorded expectation statements, i.e. migrated references to the field. */
		int getExpectStatementCount() {
			return expectStatementIndices.size();
		}

		void addExpectStatement(int index) {
			expectStatementIndices.add(index);
		}
	}
}