Spring与Mockito组合进行单元测试：简单实用

85344193 2011-11-23

今天有点空闲，把单元测试相关的内容整理了一下。个人觉得在Spring应用中结合Mockito做单元测试简单实用，分享出来供大家参考。

这里不再赘述单元测试的重要性。很多应用都基于Spring，而Mockito简单易用、容易上手，所以就把Spring和Mockito组合起来做单元测试，Mock对象也交给Spring统一管理。这样做的好处至少有以下几点：单元测试类的运行环境与应用的实际环境保持一致；开发人员不需要额外增加配置，也可以少写一些代码；单元测试通过了，相应的应用类也就没有问题了（Spring的相关配置也同时得到了验证）。

1. 为了把Mock对象也纳入Spring管理，需要继承DependencyInjectionTestExecutionListener，在其中创建Mock对象，并把被测Bean所依赖的对象替换为这些Mock对象。

代码：

/**
 * Spring test execution listener that, after the standard dependency injection of the test
 * instance, creates a Mockito mock for every {@code @Mock} field of the test class and pushes
 * those mocks into the beans referenced by the test's {@code @Autowired} fields.
 *
 * <p>Mocks are matched to bean members by the name of the test field. A mock held in test field
 * {@code fooDao} is treated as follows:
 * <ul>
 *   <li>For a JDK dynamic proxy, it is written into the target's declared field named
 *       {@code fooDao}.</li>
 *   <li>Otherwise, it is passed to a single-argument setter named {@code setFooDao}
 *       (case-insensitive).</li>
 * </ul>
 * A mock is only applied when its type is compatible with the field or parameter.
 *
 * <p>NOTE(review): the mocks are written into the application context's beans and are not
 * restored afterwards. If the context is cached and shared, other tests may observe them.
 * Confirm this is acceptable, or mark such tests with {@code @DirtiesContext}.
 */
public class MockitoDependencyInjectionTestExecutionListener extends DependencyInjectionTestExecutionListener {

    /** Mocks created for the current test instance, keyed by the name of the {@code @Mock} field. */
    private static final Map<String, MockObject> mockObject   = new HashMap<String, MockObject>();
    /** {@code @Autowired} fields of the current test instance whose beans receive the mocks. */
    private static final List<Field>             injectFields = new ArrayList<Field>();

    /**
     * Performs Spring's regular dependency injection, then creates the mocks and injects them.
     */
    @Override
    protected void injectDependencies(final TestContext testContext) throws Exception {
        super.injectDependencies(testContext);
        init(testContext);
    }

    /**
     * Setter-based injection of the registered mocks.
     *
     * <p>For every {@code @Autowired} field of the test instance, takes the bean that was injected
     * into that field. Each mock is then passed to a matching single-argument setter declared by
     * the bean's class. A setter matches when its name equals {@code "set" + mockFieldName},
     * ignoring case, and its parameter type accepts the mock.
     *
     * @param testContext the current test context; its test instance holds the injected beans
     * @throws IllegalAccessException if a field or setter cannot be accessed
     * @throws InvocationTargetException if a setter throws
     */
    protected void injectMock(final TestContext testContext) throws IllegalArgumentException, IllegalAccessException, InvocationTargetException {
        Object testInstance = testContext.getTestInstance();
        for (Field field : injectFields) {
            field.setAccessible(true);
            // Use the bean Spring actually injected. A lookup by field name fails when the bean
            // name differs from the field name (@Autowired resolves by type), and it would create
            // a throwaway instance for prototype-scoped beans.
            Object o = field.get(testInstance);
            if (o == null) {
                continue;
            }
            for (Method method : o.getClass().getDeclaredMethods()) {
                if (!method.getName().startsWith("set") || method.getParameterTypes().length != 1) {
                    continue;
                }
                Class<?> parameterType = method.getParameterTypes()[0];
                for (Map.Entry<String, MockObject> entry : mockObject.entrySet()) {
                    Object mockInstance = entry.getValue().getObj();
                    if (method.getName().equalsIgnoreCase("set" + entry.getKey())
                            && parameterType.isInstance(mockInstance)) {
                        // Declared setters may be non-public.
                        method.setAccessible(true);
                        method.invoke(o, mockInstance);
                        break;
                    }
                }
            }
        }
    }

    /**
     * Creates mocks for the {@code @Mock} fields of the test instance and records the
     * {@code @Autowired} fields. It then injects the mocks into the beans those fields reference.
     */
    private void init(final TestContext testContext) throws Exception {
        // The collections are static: reset them so that fields and mocks from a previous test
        // instance (possibly of another test class) are not applied to this one.
        mockObject.clear();
        injectFields.clear();

        Object bean = testContext.getTestInstance();
        for (Field field : bean.getClass().getDeclaredFields()) {
            for (Annotation antt : field.getAnnotations()) {
                if (antt instanceof org.mockito.Mock) {
                    // Create the mock and store it in the test field.
                    MockObject obj = new MockObject();
                    obj.setType(field.getType());
                    obj.setObj(mock(field.getType()));
                    field.setAccessible(true);
                    field.set(bean, obj.getObj());
                    mockObject.put(field.getName(), obj);
                } else if (antt instanceof Autowired) {
                    // Only autowired fields get their dependencies re-injected.
                    injectFields.add(field);
                }
            }
        }

        boolean needsSetterInjection = false;
        for (Field field : injectFields) {
            field.setAccessible(true);
            Object object = field.get(bean);
            if (object instanceof Proxy) {
                // For a proxy, inject directly into the real target object.
                injectIntoProxyTarget(object);
            } else {
                needsSetterInjection = true;
            }
        }
        // injectMock already walks every autowired field, so one call is enough.
        if (needsSetterInjection) {
            injectMock(testContext);
        }
    }

    /**
     * Writes each mock into the declared field of the proxy's target whose name matches the mock's
     * test-field name. The field's type must accept the mock.
     */
    private void injectIntoProxyTarget(final Object proxy) throws Exception {
        Class<?> targetClass = AopUtils.getTargetClass(proxy);
        if (targetClass == null) {
            // Possibly a remote implementation: nothing to inject into; skip only this field.
            return;
        }
        Object target = getTargetObject(proxy, targetClass);
        if (target == null) {
            return;
        }
        for (Field targetField : targetClass.getDeclaredFields()) {
            MockObject candidate = mockObject.get(targetField.getName());
            if (candidate != null && targetField.getType().isInstance(candidate.getObj())) {
                targetField.setAccessible(true);
                targetField.set(target, candidate.getObj());
            }
        }
    }

    /**
     * Returns the object behind a proxy.
     *
     * <p>For a Spring JDK dynamic proxy this is the target of its {@code TargetSource}. For
     * anything else the argument itself is returned, on the assumption that it is a CGLIB proxy
     * (a subclass of the target).
     *
     * @param proxy the possibly-proxied bean
     * @param targetClass the expected target type (used only for the generic return type)
     * @return the target object
     */
    @SuppressWarnings("unchecked")
    protected <T> T getTargetObject(Object proxy, Class<T> targetClass) throws Exception {
        if (AopUtils.isJdkDynamicProxy(proxy)) {
            return (T) ((Advised) proxy).getTargetSource().getTarget();
        } else {
            return (T) proxy; // expected to be cglib proxy then, which is simply a specialized class
        }
    }

    /** Holder pairing a created mock with the declared type of the test field it was made for. */
    public static class MockObject {
        private Object   obj;
        private Class<?> type;

        public MockObject(){
        }

        public Object getObj() {
            return obj;
        }

        public void setObj(Object obj) {
            this.obj = obj;
        }

        public Class<?> getType() {
            return type;
        }

        public void setType(Class<?> type) {
            this.type = type;
        }
    }

    /** Returns the (live, mutable) map of mocks created for the current test instance. */
    public static Map<String, MockObject> getMockobject() {
        return mockObject;
    }

    /** Returns the (live, mutable) list of {@code @Autowired} fields of the current test instance. */
    public static List<Field> getInjectfields() {
        return injectFields;
    }
}

2. 单元测试类继承AbstractJUnit4SpringContextTests。这里用一个抽象基类把必需的注解统一加上。

代码：

/**
 * Shared base class for Spring + Mockito unit tests.
 *
 * <p>Runs tests with the Spring JUnit 4 runner and registers
 * MockitoDependencyInjectionTestExecutionListener. Subclasses therefore only need a
 * {@code @ContextConfiguration} and their {@code @Autowired} / {@code @Mock} fields.
 */
@RunWith(SpringJUnit4ClassRunner.class)
@TestExecutionListeners(MockitoDependencyInjectionTestExecutionListener.class)
public abstract class BaseTestCase extends AbstractJUnit4SpringContextTests {
    // Intentionally empty: this class exists only to carry the shared annotations.
}

3. 测试用例示例（demo）。注意：MockitoDependencyInjectionTestExecutionListener中已经创建了Mock对象，不要再调用MockitoAnnotations.initMocks(this)。否则测试类中的@Mock字段会被替换成新的Mock，与已注入到Spring Bean中的Mock不再是同一个对象，设置的预期也就不会生效。

代码：

/**
 * Demo test for {@code DemoService}.
 *
 * <p>{@code demoDao} is a Mockito mock that MockitoDependencyInjectionTestExecutionListener
 * creates and injects into the Spring beans. The autowired {@code demoService} is therefore
 * exercised against the stubbed DAO.
 *
 * <p>Do not call {@code MockitoAnnotations.initMocks(this)} here. The listener already creates
 * the mocks, and re-initializing would replace the field with a new mock that the beans do not
 * see.
 */
@ContextConfiguration(locations = { "/applicationContext.xml" })
public class DemoServiceTest extends BaseTestCase {

    @Autowired
    private DemoService demoService;

    @Mock
    private DemoDao     demoDao;

    @Test
    public void testGetResults() {
        Model model = new Model("id", "name");

        // Set up the expectation first.
        when(demoDao.getModel("id")).thenReturn(model);
        // Alternative: compute the stubbed result lazily with an Answer:
        // when(demoDao.getResults()).thenAnswer(new Answer<List<Model>>() {
        //
        // @Override
        // public List<Model> answer(InvocationOnMock invocation) throws Throwable {
        // List<Model> list = new ArrayList<Model>();
        // Model model = new Model("id", "name");
        // list.add(model);
        // return list;
        // }
        //
        // });

        Model m = demoService.getModel("id");
        // JUnit's assertSame takes (expected, actual); keep that order for correct failure messages.
        assertSame(model, m);
        // List<Model> result = demoService.getResults();
        // assertTrue(result.size() == 1);
    }
}

相关推荐