import org.jmock.Mockery;
import org.jmock.integration.junit4.JMock;
import org.junit.internal.runners.InitializationError;
import org.junit.internal.runners.MethodRoadie;
import org.junit.internal.runners.TestMethod;
import org.junit.runner.Description;
import org.junit.runner.notification.RunNotifier;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;


/**
 * Test runner built on jMock's {@link JMock} runner that automatically fills
 * every field annotated with {@code @Mocked} with a mock object.
 *
 * <p>The annotated fields are collected once, at construction time, from the
 * test class and all of its superclasses. For each test method a fresh test
 * instance is created, and just before the inherited before/test/after
 * sequence runs, each collected field on that instance is assigned a mock
 * created by the {@link Mockery} that {@code mockeryOf(test)} returns.
 */
public class JMockRunner extends JMock
{
    /** Fields annotated with {@code @Mocked}; made accessible in the constructor. */
    private final List<Field> _mockedFields;

    /**
     * Creates a runner for the given test class and locates its
     * {@code @Mocked} fields.
     *
     * @param testClass the test class to run
     * @throws InitializationError if the superclass constructor rejects the class
     */
    public JMockRunner(Class<?> testClass) throws InitializationError {
        super(testClass);
        _mockedFields = findAnnotatedFields(testClass, Mocked.class);
        for (Field f : _mockedFields) {
            // The fields may be private (or otherwise inaccessible), so access
            // checks are lifted up front to allow Field.set later on.
            f.setAccessible(true);
        }
    }

    /**
     * Runs a single test method on a newly created test instance.
     *
     * <p>If the test instance cannot be created, the test is reported as
     * aborted through the notifier and the method is not run. Otherwise the
     * method is wrapped via {@code wrapMethod} and executed through a
     * {@link MockingMethodRoadie}, which injects the mocks first.
     *
     * @param method   the test method to invoke
     * @param notifier the notifier that receives test events
     */
    @Override
    protected void invokeTestMethod(Method method, RunNotifier notifier) {
        Description description = methodDescription(method);
        Object test;
        try {
            test = createTest();
        } catch (InvocationTargetException e) {
            // Report the underlying cause rather than the reflection wrapper.
            notifier.testAborted(description, e.getCause());
            return;
        } catch (Exception e) {
            notifier.testAborted(description, e);
            return;
        }
        TestMethod testMethod = wrapMethod(method);
        new MockingMethodRoadie(test, testMethod, notifier, description)
                .run();
    }

    /**
     * Assigns a newly created mock to every {@code @Mocked} field of the
     * given test instance. Each mock is created from the field's declared
     * type by the mockery obtained from {@code mockeryOf(test)}.
     *
     * @param test the test instance whose fields are populated
     * @throws IllegalAccessException if a field cannot be written
     */
    private void defineMockFields(Object test) throws IllegalAccessException {
        Mockery mockery = mockeryOf(test);
        for (Field f : _mockedFields) {
            Object mocked = mockery.mock(f.getType());
            f.set(test, mocked);
        }
    }

    /**
     * A {@link MethodRoadie} that injects the mocks into the test instance
     * before delegating to the inherited before/test/after sequence.
     *
     * <p>Deliberately a non-static inner class: it calls the enclosing
     * runner's {@code defineMockFields}.
     */
    class MockingMethodRoadie extends MethodRoadie
    {
        /** The test instance that receives the mocks. */
        private final Object _test;

        MockingMethodRoadie(Object test, TestMethod method, RunNotifier notifier, Description description) {
            super(test, method, notifier, description);
            _test = test;
        }

        /**
         * Populates the {@code @Mocked} fields, then runs the inherited
         * sequence.
         *
         * @param test the runnable handed to the superclass implementation
         * @throws RuntimeException if a mocked field cannot be assigned; the
         *         original {@link IllegalAccessException} is kept as the cause
         */
        @Override
        public void runBeforesThenTestThenAfters(Runnable test) {
            try {
                defineMockFields(_test);
            } catch (IllegalAccessException e) {
                throw new RuntimeException("Problem while trying to mock : ", e);
            }
            super.runBeforesThenTestThenAfters(test);
        }
    }

    /**
     * Collects every field carrying the given annotation that is declared on
     * {@code clazz} or on any of its superclasses, excluding
     * {@link Object}. Fields of the most derived class come first, followed
     * by those of each successive superclass.
     *
     * <p>The walk stops when the superclass chain ends, so a class whose
     * {@code getSuperclass()} is {@code null} (for example an interface)
     * contributes only its own declared fields, and a {@code null}
     * {@code clazz} yields an empty list.
     *
     * @param clazz      the class at which the search starts; the parameter
     *                   stays a raw {@code Class} so the method signature is
     *                   unchanged for existing callers and overriders
     * @param annotation the annotation type to look for
     * @return a new, mutable list of the matching fields; never {@code null}
     */
    public List<Field> findAnnotatedFields(Class clazz, Class<? extends Annotation> annotation) {
        List<Field> results = new ArrayList<Field>();
        // getSuperclass() returns null once the chain is exhausted for
        // anything that does not descend from Object via classes (interfaces,
        // primitives, void), hence the explicit null guard.
        for (Class<?> current = clazz;
             current != null && current != Object.class;
             current = current.getSuperclass()) {
            for (Field f : current.getDeclaredFields()) {
                if (f.isAnnotationPresent(annotation)) {
                    results.add(f);
                }
            }
        }
        return results;
    }

}

