迷你版Spring MVC 实现

项目建立

首先,建立一个空的 Maven 项目,在 src/main/java 目录下建立以下包结构:

.....
  |--com.it
    |-annotation
    |-dao
    |-service
    |-web
    DispatchServlet

接下来,在 pom.xml 中引入 servlet 依赖:

<!-- The servlet container (e.g. Tomcat) supplies the Servlet API at runtime, so it is only
     needed for compilation and must not be packaged into the WAR ("provided" scope).
     This is the javax.* API, so it needs a Servlet 3.x/4.x container such as Tomcat 8/9;
     Tomcat 10+ uses the jakarta.servlet namespace instead. -->
<dependency>
    <groupId>javax.servlet</groupId>
    <artifactId>javax.servlet-api</artifactId>
    <version>3.1.0</version>
    <scope>provided</scope>
</dependency>

在 annotation 包下增加如下注解:

Autowired.java

/**
 * Marks a field that the framework fills in after all beans have been instantiated.
 * The value names the bean to inject.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
@Documented
public @interface Autowired {

    /** Name of the bean to inject; there is no default, so it must always be given. */
    String value();
}

Controller.java

/**
 * Marks a class as a web controller bean whose request-mapped methods handle requests.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
@Documented
public @interface Controller {

    /** Bean name; an empty string means "derive it from the class name". */
    String value() default "";
}

Repository.java

/**
 * Marks a class as a data-access bean to be instantiated and registered by the framework.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
@Documented
public @interface Repository {

    /** Bean name; an empty string means "derive it from the class name". */
    String value() default "";
}

RequestMapping.java

/**
 * Binds a request path to a controller class (path prefix) or to a handler method.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
@Documented
public @interface RequestMapping {

    /** Path segment, e.g. "/user"; mandatory because there is no default. */
    String value();
}

Service.java

/**
 * Marks a class as a service bean to be instantiated and registered by the framework.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
@Documented
public @interface Service {

    /** Bean name; an empty string means "derive it from the class name". */
    String value() default "";
}

编写核心Servlet

在 Spring MVC 中,DispatcherServlet 是核心,因此我们也建立一个自己的 DispatchServlet 来接收并分发所有请求。

DispatchServlet.java

/**
 * Front controller of the mini MVC framework: it receives every request under "/*"
 * and holds the registries that map request paths to controller methods.
 */
@WebServlet(urlPatterns = "/*", loadOnStartup = 0,
        initParams = {@WebInitParam(name = "base-package", value = "com.it")})
public class DispatchServlet extends HttpServlet {

    private static final String EMPTY = "";

    /** Base package to scan, read from the "base-package" init parameter. */
    private String basePackage = EMPTY;

    /** Fully qualified names of the classes found under the base package. */
    private final List<String> packagesName = new ArrayList<>();

    /** Bean name (annotation value, or the class name with a lower-cased first letter) -> bean instance. */
    private final Map<String, Object> instanceMap = new HashMap<>();

    /** Fully qualified class name -> bean name (same naming rule as above). */
    private final Map<String, String> nameMap = new HashMap<>();

    /** Request path (relative to the context path) -> handler method. */
    private final Map<String, Method> urlMethodMap = new HashMap<>();

    /** Handler method -> fully qualified name of the controller class it was found on. */
    private final Map<Method, String> methodPackageMap = new HashMap<>();

}

接下来,需要在 servlet 初始化的时候扫描被我们的注解标识的类,总体步骤如下:

1)扫描指定包下所有的类,记录它们的全类名。

2)实例化所有标注了 @Controller/@Service/@Repository 注解的类,并记录 bean 名称与实例、全类名与 bean 名称的映射关系。

3)对于标注了 @Autowired 注解的字段,需要为其注入相应类的实例。为了方便,这里统一按 @Autowired 中指定的 bean 名称去获取实例(Spring 中还可以按类型等多种方式注入)。未显式指定名称的 bean,其默认名称为类名首字母小写的形式,例如 UserDao 对应 userDao。

4)对于所有标注了 @RequestMapping 注解的方法,需要将请求路径(类上的路径加上方法上的路径)与相应的方法做关联记录,这样就可以通过请求路径找到对应的 Controller 方法。

Servlet初始化

/**
 * Bootstraps the framework: scans the configured base package, instantiates the annotated
 * beans, injects {@code @Autowired} fields and builds the URL-to-method mapping.
 * Failures are reported on stdout and the servlet stays deployed (best effort).
 *
 * @param config servlet configuration carrying the "base-package" init parameter
 * @throws ServletException if the superclass initialisation fails
 */
@Override
public void init(ServletConfig config) throws ServletException {
    // GenericServlet.init(ServletConfig) stores the config; overriding without calling it
    // leaves getServletConfig()/getServletContext() unusable for this servlet.
    super.init(config);
    String configured = config.getInitParameter("base-package");
    if (configured == null || configured.trim().isEmpty()) {
        // Without a package there is nothing to scan; the scan would otherwise NPE.
        System.out.println("未配置 base-package,框架加载失败");
        return;
    }
    basePackage = configured.trim();
    try {
        scanBasePackage(basePackage);
        instance(packagesName);
        ioc();
        handlerUrlMethod();
    } catch (Exception e) {
        e.printStackTrace();
        System.out.println("框架加载失败");
    }
}

第一步,包名扫描

 /**
  * Recursively collects the fully qualified names of all classes under {@code basePackage}
  * into {@code packagesName}.
  *
  * <p>Only packages that live in a directory on the class path are scanned; a package that
  * cannot be found, or that is not a plain "file:" location, is skipped.
  *
  * @param basePackage dotted package name, e.g. {@code com.it}
  */
 private void scanBasePackage(String basePackage) {
     URL resource = this.getClass().getClassLoader().getResource(basePackage.replace('.', '/'));
     if (resource == null) {
         return;
     }
     File packageDir;
     try {
         // URL.getPath() keeps %-escapes (e.g. %20 for a space, encoded non-ASCII characters),
         // which java.io.File cannot resolve; going through a URI decodes them.
         packageDir = new File(resource.toURI());
     } catch (java.net.URISyntaxException | IllegalArgumentException e) {
         // IllegalArgumentException: the URI is not a hierarchical "file:" URI (e.g. inside a jar).
         return;
     }
     File[] children = packageDir.listFiles();
     if (children == null) {
         return;
     }
     for (File child : children) {
         String fileName = child.getName();
         if (child.isDirectory()) {
             scanBasePackage(basePackage + "." + fileName);
         } else if (child.isFile() && fileName.endsWith(".class")) {
             // Only compiled classes are candidates; other resources are not loadable types.
             String className = fileName.substring(0, fileName.length() - ".class".length());
             packagesName.add(basePackage + "." + className);
         }
     }
 }

第二步,实例化类

/**
 * Instantiates every scanned class annotated with {@link Controller}, {@link Service} or
 * {@link Repository} and registers it in {@code instanceMap} and {@code nameMap}.
 *
 * <p>The bean name is the annotation's value, or the simple class name with its first letter
 * lower-cased when the value is empty. A class that fails to load or instantiate is reported
 * and skipped; the remaining classes are still processed.
 *
 * @param packagesNames fully qualified class names produced by the package scan
 */
private void instance(List<String> packagesNames) {
    for (String packageName : packagesNames) {
        try {
            Class<?> clazz = Class.forName(packageName);
            String defaultName = defaultBeanName(clazz);
            if (clazz.isAnnotationPresent(Controller.class)) {
                registerBean(clazz, packageName, clazz.getAnnotation(Controller.class).value(), defaultName);
            }
            if (clazz.isAnnotationPresent(Service.class)) {
                registerBean(clazz, packageName, clazz.getAnnotation(Service.class).value(), defaultName);
            }
            if (clazz.isAnnotationPresent(Repository.class)) {
                registerBean(clazz, packageName, clazz.getAnnotation(Repository.class).value(), defaultName);
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.out.println("类加载失败");
        }
    }
}

/** Returns the part of the class name after the last '.', with its first letter lower-cased. */
private static String defaultBeanName(Class<?> clazz) {
    String fullName = clazz.getName();
    String subName = fullName.substring(fullName.lastIndexOf('.') + 1);
    // Character.toLowerCase is locale-independent, unlike String.toLowerCase() (e.g. Turkish 'I').
    return Character.toLowerCase(subName.charAt(0)) + subName.substring(1);
}

/** Creates a new instance of {@code clazz} and records it under the annotated or default name. */
private void registerBean(Class<?> clazz, String packageName, String annotatedName, String defaultName)
        throws ReflectiveOperationException {
    String beanName = EMPTY.equals(annotatedName) ? defaultName : annotatedName;
    // Constructor.newInstance wraps constructor exceptions in InvocationTargetException, whereas
    // the deprecated Class.newInstance rethrows even checked ones undeclared.
    instanceMap.put(beanName, clazz.getDeclaredConstructor().newInstance());
    nameMap.put(packageName, beanName);
}

第三步,注入类的实例

/**
 * Injects beans into every field annotated with {@link Autowired}. The bean is looked up in
 * {@code instanceMap} by the annotation's value. Fields declared on superclasses are included.
 * When no bean is registered under the requested name, the field is left untouched and a
 * message is printed, instead of silently assigning {@code null}.
 *
 * @throws IllegalAccessException if a field cannot be written
 */
private void ioc() throws IllegalAccessException {
    for (Object bean : instanceMap.values()) {
        for (Class<?> type = bean.getClass(); type != null && type != Object.class; type = type.getSuperclass()) {
            for (Field field : type.getDeclaredFields()) {
                Autowired autowired = field.getAnnotation(Autowired.class);
                if (autowired == null) {
                    continue;
                }
                Object dependency = instanceMap.get(autowired.value());
                if (dependency == null) {
                    // Surface the wiring error now rather than as an NPE on the first request.
                    System.out.println("注入失败: 找不到 bean " + autowired.value()
                            + " (" + type.getName() + "." + field.getName() + ")");
                    continue;
                }
                field.setAccessible(true);
                field.set(bean, dependency);
            }
        }
    }
}

第四步,处理请求路径与方法的映射

/**
 * Maps request paths to controller methods. For each {@link Controller} class, the path of
 * every public method annotated with {@link RequestMapping} is the class-level mapping (if
 * any) followed by the method-level mapping, e.g. "/user" + "/insert".
 */
private void handlerUrlMethod() {
    for (String packageName : packagesName) {
        try {
            Class<?> clazz = Class.forName(packageName);
            if (!clazz.isAnnotationPresent(Controller.class)) {
                continue;
            }
            RequestMapping classMapping = clazz.getAnnotation(RequestMapping.class);
            String baseUrl = classMapping == null ? EMPTY : classMapping.value();
            for (Method method : clazz.getMethods()) {
                RequestMapping methodMapping = method.getAnnotation(RequestMapping.class);
                if (methodMapping == null) {
                    continue;
                }
                // Build each URL fresh from the class prefix so sibling methods' paths
                // don't accumulate into one another.
                String url = baseUrl + methodMapping.value();
                Method previous = urlMethodMap.put(url, method);
                if (previous != null && !previous.equals(method)) {
                    System.out.println("重复的请求映射: " + url);
                }
                methodPackageMap.put(method, packageName);
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.out.println("获取映射失败");
        }
    }
}

编写GET/POST方法

编写GET/POST方法,使请求映射到我们的Controller上

/** GET requests are dispatched exactly like POST requests. */
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
    this.doPost(req, resp);
}


/**
 * Dispatches the request to the controller method mapped to its path, which is the request
 * URI with the leading context path removed. Replies 404 when no mapping matches and 500
 * when the handler cannot be invoked or throws.
 */
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
    String uri = req.getRequestURI();
    String contextPath = req.getContextPath();
    // Strip only the leading context path; String.replace would also remove later occurrences
    // (e.g. with context "/user", "/user/user/insert" would become "/insert").
    String path = uri.startsWith(contextPath) ? uri.substring(contextPath.length()) : uri;
    Method method = urlMethodMap.get(path);
    if (method == null) {
        resp.sendError(HttpServletResponse.SC_NOT_FOUND);
        return;
    }
    String packageName = methodPackageMap.get(method);
    String controllerName = nameMap.get(packageName);
    Object controllerObj = instanceMap.get(controllerName);
    try {
        method.setAccessible(true);
        method.invoke(controllerObj);
    } catch (Exception e) {
        e.printStackTrace();
        // Report the failure to the client instead of answering 200 with an empty body.
        if (!resp.isCommitted()) {
            resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
        }
    }
}

编写测试

UserDao.java

/** Sample data-access bean; with no explicit name it is registered as "userDao". */
@Repository
public class UserDao {

    /** Simulates persisting a user by printing a marker line to standard output. */
    public void insert() {
        final String message = "insert user:";
        System.out.println(message);
    }
}

UserService.java

/** Sample service bean that delegates user creation to {@link UserDao}. */
@Service
public class UserService {

    /** Filled in by name ("userDao") during the framework's injection step. */
    @Autowired("userDao")
    private UserDao userDao;

    /** Inserts a user through the injected DAO. */
    public void insert() {
        this.userDao.insert();
    }
}

UserController.java

/** Sample controller; its handlers live under the "/user" path prefix. */
@Controller
@RequestMapping("/user")
public class UserController {

    /** Filled in by name ("userService") during the framework's injection step. */
    @Autowired("userService")
    private UserService userService;

    /** Handles "/user/insert" (class prefix "/user" + method path "/insert"). */
    @RequestMapping("/insert")
    public void insert() {
        this.userService.insert();
    }
}

编写完成后,将项目部署到 Tomcat 并启动,访问 http://IP:PORT/<上下文路径>/user/insert 后,可以在控制台看到 insert user: 的输出,说明我们的迷你版 Spring MVC 可以正常工作啦。

代码参考:mymvc

01-29 20:11