理解工具的最好的方法就是动手实践

Lexer文件

lexer grammar MethodTaskLexer;

OPEN : '(' -> pushMode(INSIDE)
;

// bean.method
BEAN_DOT_METHOD
    : [a-zA-Z]+[0-9a-zA-Z_]*'.'[a-zA-Z]+[0-9a-zA-Z_]*
    ;

mode INSIDE;

CLOSE : ')' -> popMode
;

// 参数值
// 允许两种参数类型 - 一种是占位符;另一种是指定的JDK原生类型
PARAM_VALUE
    : STRING    // 字符串
    | DIGIT+    // 数字
    | DIGIT+ '.' DIGIT*     // 浮点
    | '.' DIGIT+    // 浮点
    | 'true'|'false'   // bool
//    | '\''
    ;

// 占位符 -- 例如 ${instance} 表示instance参数
PARAM_PLACEHOLDER
    : '${'[_0-9a-zA-Z]+'}'
    ;

PRIMITIVE_TYPE_EXT
    : 'decimal'     // for BigDecimal
    ;

// jdk原生类型 以及 classpath一般正则匹配
PRIMITIVE_TYPE
    : 'boolean'
    | 'int'
    | 'String'
    | 'float'
    | 'long'
    | 'double'
    | 'short'
    | 'byte'
    | 'char'
    | 'Boolean'
    | 'Integer'
    | 'Float'
    | 'Long'
    | 'Double'
    | 'Short'
    | 'Byte'
    | 'Character'
    ;

CLASS_PATH
    : [0-9a-zA-Z_]+(.[0-9a-zA-Z_]+)*'.'[a-zA-Z]+[0-9a-zA-Z_]*
    ;

STRING
//    : '"' (ESC|.)*? '"'    // 匹配在双引号中的任意字符
    : '"' .*? '"'    // 匹配在双引号中的任意字符
    ;

COMMA: ',' ;

WS : [ \n\t\r]+ -> skip;

fragment
DIGIT: [0-9] ;
ESC : '\\"' | '\\\\' ;    // 匹配字符\"和\\

Parser文件

parser grammar MethodTaskParser;
options { tokenVocab=MethodTaskLexer; }

//prog
//    : expr +
//    ;

expr
    : BEAN_DOT_METHOD OPEN CLOSE
    | BEAN_DOT_METHOD OPEN paramDefGroup CLOSE
    ;

paramDefGroup
    : paramDef (COMMA paramDef)*
    ;

paramDef
    : CLASS_PATH PARAM_VALUE
    | CLASS_PATH PARAM_PLACEHOLDER
    | PRIMITIVE_TYPE PARAM_VALUE
    | PRIMITIVE_TYPE PARAM_PLACEHOLDER
    | PRIMITIVE_TYPE_EXT PARAM_VALUE
    | PRIMITIVE_TYPE_EXT PARAM_PLACEHOLDER
    ;

lexerparser也可以写在一个文件里面,不过antlr存在一些限制,如果写在一个文件中就不能在定义token的时候使用词法模型特性了;

模版代码生成 & 填充业务逻辑

本文中,我们使用listener模式计算结果,antlr除了支持listener模式之外也支持visitor模式, 第一步是先使用antlr生成相关的基础设施,我使用了intellij ideaantlr的插件(也可以命令行,这里需要注意的是antlr版本需要和插件版本保持一致,否则会失败):

然后就可以看到相关的文件:

把相关的java class拷贝到项目目录下, 继承 MethodTaskParserBaseListener,在基础设施的实现中会对生成的AST语法树进行深度优先遍历,因此我们就可以针对相关的语法做出对应的响应从而实现DSL的逻辑。

class Runner extends MethodTaskParserBaseListener {
        // 任务上下文
        private final JobContext jobContext;
        // 表达式原文
        private final String expression;

        private String beanName;
        private String methodName;
        private final List<Class> paramTypes = new ArrayList<>();
        private final List<Object> paramValues = new ArrayList<>();

        private boolean isWalkingPlaceHolder = false;

        private Object result;

        public Runner(JobContext jobContext, String expression) {
            super();
            this.jobContext = jobContext;
            this.expression = expression;
        }

        @Override
        public void enterExpr(ExprContext ctx) {
            //System.out.println("enterExpr");
            log.info("开始解析任务表达式, expr:{}, parsed_expr: {}", expression, ctx.toStringTree());

            String beanDotMethod = ctx.BEAN_DOT_METHOD().getText();
            String[] a = beanDotMethod.split("\\.");
            setBeanName(a[0]);
            setMethodName(a[1]);
        }

        @Override
        public void exitExpr(ExprContext ctx) {
            //System.out.println("exitExpr");
            //运行脚本
            try {
                // find bean
                Object obj = SpringContext.getBeanByName(beanName);
                if (null != obj) {
                    // find method
                    Method method;
                    try {
                        method = obj.getClass().getMethod(methodName, paramTypes.toArray(new Class[0]));
                    } catch (Exception e) {
                        throw new RuntimeException("找不到对应的方法", e);
                    }
                    // invoke method
                    result = method.invoke(obj, paramValues.toArray());
                } else {
                    // bean不存在
                    throw new RuntimeException("bean:" + beanName + "不存在");
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public void enterParamDef(ParamDefContext ctx) {
            //System.out.println("enterParamDef: " + ctx.toStringTree());

            Function<String, Object> converter = null;
            try {
                // 先解析类型
                if (ctx.CLASS_PATH() != null) {
                    Class<?> clazz = Class.forName(ctx.CLASS_PATH().getText());
                    paramTypes.add(clazz);
                    converter = (s) -> GsonUtil.fromJson(s, clazz);
                } else if (ctx.PRIMITIVE_TYPE() != null) {
                    String primitiveType = ctx.PRIMITIVE_TYPE().getText();
                    switch (primitiveType) {
                        case "boolean":
                            paramTypes.add(boolean.class);
                            converter = Boolean::parseBoolean;
                            break;
                        case "int":
                            paramTypes.add(int.class);
                            converter = Integer::parseInt;
                            break;
                        case "String":
                            paramTypes.add(String.class);
                            converter = (s) -> {
                                return isWalkingPlaceHolder ? s : s.substring(1, s.length() - 1);
                            };
                            break;
                        case "float":
                            paramTypes.add(float.class);
                            converter = Float::parseFloat;
                            break;
                        case "long":
                            paramTypes.add(long.class);
                            converter = Long::parseLong;
                            break;
                        case "double":
                            paramTypes.add(double.class);
                            converter = Double::parseDouble;
                            break;
                        case "short":
                            paramTypes.add(short.class);
                            converter = Short::parseShort;
                            break;
                        case "byte":
                            paramTypes.add(byte.class);
                            converter = Byte::parseByte;
                            break;
                        case "char":
                            paramTypes.add(char.class);
                            converter = (s) -> s.charAt(0);
                            break;
                        case "Boolean":
                            paramTypes.add(Boolean.class);
                            converter = Boolean::valueOf;
                            break;
                        case "Integer":
                            paramTypes.add(Integer.class);
                            converter = Integer::valueOf;
                            break;
                        case "Float":
                            paramTypes.add(Float.class);
                            converter = Float::valueOf;
                            break;
                        case "Long":
                            paramTypes.add(Long.class);
                            converter = Long::valueOf;
                            break;
                        case "Double":
                            paramTypes.add(Double.class);
                            converter = Double::valueOf;
                            break;
                        case "Short":
                            paramTypes.add(Short.class);
                            converter = Short::valueOf;
                            break;
                        case "Byte":
                            paramTypes.add(Byte.class);
                            converter = Byte::valueOf;
                            break;
                        case "Character":
                            paramTypes.add(Character.class);
                            converter = (s) -> Character.valueOf(s.charAt(0));
                            break;
                        default:
                            throw new RuntimeException("error, primitiveType:" + primitiveType + " not supported");
                    }
                } else if (ctx.PRIMITIVE_TYPE_EXT() != null) {
                    String primitiveTypeExt = ctx.PRIMITIVE_TYPE_EXT().getText();
                    if ("decimal".equals(primitiveTypeExt)) {
                        paramTypes.add(BigDecimal.class);
                        converter = BigDecimal::new;
                    } else {
                        throw new RuntimeException("error, primitiveTypeExt:" + primitiveTypeExt + " not supported");
                    }
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }

            // 再解析参数值
            if (ctx.PARAM_VALUE() != null) {
                String paramValue = ctx.PARAM_VALUE().getText();
                try {
                    paramValues.add(converter.apply(paramValue));
                } catch (Exception e) {
                    throw new RuntimeException("参数值解析异常", e);
                }
            } else if (ctx.PARAM_PLACEHOLDER() != null) {
                try {
                    isWalkingPlaceHolder = true;
                    String placeHolder = ctx.PARAM_PLACEHOLDER().getText();
                    if ("${instanceParam}".equalsIgnoreCase(placeHolder)) {
                        paramValues.add(converter.apply(jobContext.getInstanceParam()));
                    } else if ("${taskParam}".equalsIgnoreCase(placeHolder)) {
                        paramValues.add(converter.apply(jobContext.getTaskParam()));
                    } else {
                        throw new RuntimeException("error, placeholder " + placeHolder + " not supported");
                    }
                } catch (Exception e) {
                    throw new RuntimeException("参数值解析异常", e);
                } finally {
                    isWalkingPlaceHolder = false;
                }
            }
        }

        @Override
        public void exitParamDef(ParamDefContext ctx) {
            //System.out.println("exitParamDef: " + ctx.toStringTree());
        }

        @Override
        public void visitErrorNode(org.antlr.v4.runtime.tree.ErrorNode node) {
            log.error("解析任务表达式失败,msg:{}", node.toString());
            throw new RuntimeException("解析任务表达式失败,msg:" + node.toString());
        }

        public Object getResult() {
            return result;
        }

        void setBeanName(String beanName) {
            this.beanName = beanName;
        }

        void setMethodName(String methodName) {
            this.methodName = methodName;
        }
    }

这样表达式解析的雏形就基本上完成了。
antlr真是一件非常棒的工具,通过它的帮助你甚至不需要从0开始编写词法&语法解析就可以实现自己的语言!实现一个解析器简直是太大材小用了!从业务编程的视角来看,它可以成为一件强大的武器,例如帮助我们在配置化驱动和硬编码之间找到第三种解决方案,在不失扩展性的同时也不需要很高的系统复杂性!