miller
发布于

基于实体类生成 db DDL

原文: https://blog.csdn.net/sunnyzyq/article/details/131597836

package com.xxxx.mdm;


import java.io.*;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.math.BigDecimal;
import java.time.LocalDate;
import java.util.*;
import java.util.function.Function;

/**
 * @author Yuanqiang.Zhang
 * @since 2023/7/5
 */
public class ClassUtils {

    public static void main(String[] args) throws Exception {
        Table table = getTable(AuditOperateLog.class);
        System.out.println(table.sql());
    }

    /**
     * Java类转化为数据库表对象
     *
     * @param c   java实体类
     * @param <T> 泛型
     * @return 数据库表对象
     */
    private static <T> Table getTable(Class<T> c) {
        List<String> codes = getClassSourceCodes(c);
        Map<String, Column> columnMap = getFieldColumnMap(c, codes);
        Class<? super T> superclass = c.getSuperclass();
        if (Objects.nonNull(superclass)) {
            if (superclass != Object.class) {
                List<String> sourceCodes = getClassSourceCodes(c);
                Map<String, Column> superColumnMap = getFieldColumnMap(superclass, sourceCodes);
                columnMap.putAll(superColumnMap);
            }
        }
        if (columnMap.isEmpty()) {
            return null;
        }
        // 获取主键(如果包含名为 id 的属性,则以 id 为主键,否则,则取类中第一个属性为主键)
        Column key;
        if (columnMap.containsKey("id")) {
            key = columnMap.get("id");
        } else {
            key = columnMap.entrySet().stream().findFirst().get().getValue();
        }
        // 剩下的则是其他属性
        LinkedList<Column> columns = new LinkedList<>(columnMap.values());
        if (!columns.get(0).getName().equals(key.getName())) {
            Column copyKey = key.copy();
            columns.remove(key);
            columns.addFirst(copyKey);
        }
        Table table = new Table();
        table.setName(getSqlName(c.getSimpleName()));
        table.setComment(getTableComment(codes));
        table.setColumns(columns);
        return table;
    }

    private static <T> Map<String, Column> getFieldColumnMap(Class<T> c, List<String> codes) {
        Map<String, String> fieldCommentMap = getClassFieldComment(codes, c.getSimpleName());
        Field[] fields = c.getDeclaredFields();
        Map<String, Column> columnMap = new LinkedHashMap<>();
        for (Field field : fields) {
            if (!Modifier.isStatic(field.getModifiers())) {
                String fieldName = field.getName();
                String columnName = getSqlName(fieldName);
                Column column = new Column();
                column.setName(columnName);
                column.setType(getSqlType(field));
                column.setComment(fieldCommentMap.get(fieldName));
                columnMap.put(fieldName, column);
            }
        }
        return columnMap;
    }

    /**
     * 获取 mysql 数据类型
     *
     * @param field java 属性
     * @return mysql的
     */
    private static String getSqlType(Field field) {
        Class<?> type = field.getType();
        String name = getSqlName(field.getName());
        if (type == String.class) {
            return "varchar(255)";
        } else if (type == int.class || type == Integer.class) {
            if (Constant.TINYINT_WORDS.contains(name)) {
                return "tinyint(4)";
            }
            if (name.startsWith("is_") || name.endsWith("_type") || name.endsWith("_status")) {
                return "tinyint(4)";
            }
            return "int(11)";
        } else if (type == long.class || type == Long.class) {
            return "bigint(20)";
        } else if (type == double.class || type == Double.class) {
            // 最大为 (53, 15)
            return "double(20,2)";
        } else if (type == boolean.class || type == Boolean.class) {
            return "boolean";
        } else if (type == Date.class || type == LocalDate.class) {
            return "datetime";
        } else if (type == BigDecimal.class) {
            // 最大为 decimal(65, 30)
            return "decimal(20,2)";
        } else {
            // 其他类型则不处理
            return null;
        }
    }

    /**
     * 获取类文件源码内容
     *
     * @param clazz 类对象
     * @return 类源码内容
     */
    private static List<String> getClassSourceCodes(Class<?> clazz) {
        String classPath = clazz.getName().replace(Constant.POINT, Constant.LEFT_LINE) + Constant.POINT_JAVA;
        String filePath;
        if (new File(Constant.POM_XML).exists()) {
            filePath = Constant.SRC_MAIN_JAVA + classPath;
        } else {
            filePath = Constant.SRC + classPath;
        }
        File file = new File(filePath);
        List<String> lines = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            String line;
            while ((line = reader.readLine()) != null) {
                lines.add(line);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return lines;
    }

    /**
     * 获取表备注(默认取类文档注释第一行文字)
     *
     * @param codes 类源码内容
     * @return 表备注
     */
    private static String getTableComment(List<String> codes) {
        if (codes.isEmpty()) {
            return null;
        }
        // 按规范来说,第一个格式为 [ * xxx] 的行,即为表的文档注释
        String comment = null;
        for (String line : codes) {
            String trimLine = line.trim();
            if (trimLine.startsWith(Constant.STAR) && line.length() > 1) {
                comment = trimLine.substring(1).trim();
                break;
            }
        }
        if (Objects.isNull(comment)) {
            return null;
        }
        // 按SQL表备注习惯来说,一般称为[xxx表],因此如果备注不以"表"字结尾,我们这里补加上
        if (comment.endsWith(Constant.TABLE)) {
            return comment;
        }
        return comment + Constant.TABLE;
    }

    /**
     * 获取类属性注释(文档注释、多行注释、单行注释、行尾注释)
     *
     * @param codes 类源码内容
     * @return 表备注
     */
    private static Map<String, String> getClassFieldComment(List<String> codes, String className) {
        boolean fieldLineStart = false;
        Map<String, String> filedCommentMap = new LinkedHashMap<>();
        String comment = null;
        boolean commentStart = false;
        for (int i = 0; i < codes.size(); i++) {
            String code = codes.get(i).trim();
            if (!fieldLineStart) {
                if (code.contains(String.format("class %s", className))) {
                    fieldLineStart = true;
                }
                continue;
            }
            if (commentStart) {
                String str = trim(code, Constant.CHAR_STAR).trim();
                if (!str.isEmpty()) {
                    comment = str;
                    commentStart = false;
                    continue;
                }
            } else {
                if (code.startsWith("/*")) {
                    commentStart = true;
                    if (code.endsWith("*/")) {
                        comment = getSingleLineDocComment(code);
                        commentStart = false;
                        continue;
                    }
                }
                if (code.startsWith("//")) {
                    comment = code.substring(2).trim();
                    commentStart = false;
                    continue;
                }
            }
            // 如果进入方法区域,则停止解析
            if (code.contains("(") && code.contains(")") && code.endsWith("{")) {
                return filedCommentMap;
            }
            // 属性判定
            if (code.contains(";") && !code.startsWith("//")) {
                String fileLine = code;
                if (code.contains("//")) {
                    String[] arr = code.split("//");
                    if (arr.length > 1 && Objects.isNull(comment)) {
                        comment = arr[1].trim();
                    }
                    fileLine = arr[0];
                }
                String fieldName = getFieldName(fileLine);
                filedCommentMap.put(fieldName, comment);
                comment = null;
                commentStart = false;
            }
        }
        return filedCommentMap;
    }

    private static String getFieldName(String s) {
        String[] arr = s.split(" ");
        for (int i = arr.length - 1; i >= 0; i--) {
            String str = arr[i];
            if (!str.isEmpty() && !str.equals(";")) {
                return str.replace(";", "");
            }
        }
        return "";
    }

    /**
     * 获取单行文档注释
     *
     * @param code 单行注释
     * @return 注释内容
     */
    private static String getSingleLineDocComment(String code) {
        // code 应该长这样(/*xxx*/),我们可以先取中间的xxx
        String xxx = code.substring(2, code.length() - 2);
        // xxx 首尾可能还包含 *,需要进一步将 xxx 首尾多余的 * 去掉,剩下的就是注释
        return trim(xxx, Constant.CHAR_STAR).trim();
    }

    /**
     * 去掉首尾指定的符号(该方法参考的是String类中的trim()方法)
     *
     * @param s 字符串
     * @param c 去掉的指定符号
     * @return 不包含首尾指定字符
     */
    public static String trim(String s, char c) {
        int var1 = s.length();
        int var2 = 0;
        char[] var3;
        for (var3 = s.toCharArray(); var2 < var1 && var3[var2] <= c; ++var2) {
        }
        while (var2 < var1 && var3[var1 - 1] <= c) {
            --var1;
        }
        return var2 <= 0 && var1 >= s.length() ? s : s.substring(var2, var1);
    }

    /**
     * 根据驼峰命名获取下划线命名
     *
     * @param javaName java的驼峰命名
     * @return sql中的下划线命名
     */
    private static String getSqlName(String javaName) {
        StringBuilder sqlName = new StringBuilder();
        for (int i = 0; i < javaName.length(); i++) {
            char c = javaName.charAt(i);
            if (Character.isUpperCase(c)) {
                if (i > 0) {
                    sqlName.append(Constant.UNDERLINE);
                }
                sqlName.append(Character.toLowerCase(c));
            } else {
                sqlName.append(c);
            }
        }
        return sqlName.toString();
    }

    /**
     * 常量类
     */
    static class Constant {
        private static final String POM_XML = "pom.xml";
        private static final String POINT = ".";
        private static final String LEFT_LINE = "/";
        private static final String POINT_JAVA = ".java";
        private static final String SRC = "src/";
        private static final String SRC_MAIN_JAVA = "src/main/java/";
        private static final String STAR = "*";
        private static final String TABLE = "表";
        private static final String UNDERLINE = "_";
        private static final char CHAR_STAR = '*';
        private static final List<String> TINYINT_WORDS = Arrays.asList("deleted", "sex", "age", "status", "type", "state");
    }

    /**
     * SQL表
     */
    static class Table {

        private String tablePrefix;
        private String columnPrefix;
        private String name;
        private String comment;
        private LinkedList<Column> columns;

        public String getName() {
            return name;
        }

        public void setName(String name) {
            this.name = name;
        }

        public String getComment() {
            return comment;
        }

        public void setComment(String comment) {
            this.comment = comment;
        }

        public List<Column> getColumns() {
            return columns;
        }

        public void setColumns(LinkedList<Column> columns) {
            this.columns = columns;
        }

        public String getTablePrefix() {
            return tablePrefix;
        }

        public void setTablePrefix(String tablePrefix) {
            this.tablePrefix = tablePrefix;
        }

        public String getColumnPrefix() {
            return columnPrefix;
        }

        public void setColumnPrefix(String columnPrefix) {
            this.columnPrefix = columnPrefix;
        }

        public String sql() {
            if (Objects.isNull(columns)) {
                columns = new LinkedList<>();
            }
            List<String> lines = new ArrayList<>();
            String tableName;
            if (Objects.nonNull(tablePrefix)) {
                tableName = tablePrefix + name;
            } else {
                tableName = name;
            }
            lines.add(String.format("CREATE TABLE `%s` (", tableName));
            Column key = columns.get(0);
            String type = key.getType();
            if (type.startsWith("tinyint") || type.startsWith("int") || type.startsWith("bigint")) {
                key.setIncreased(true);
                String comment = key.getComment();
                if (Objects.isNull(comment) || comment.isEmpty()) {
                    key.setComment("主键");
                }
            }
            for (int i = 0; i < columns.size(); i++) {
                if (columns.get(i).isEffective()) {
                    lines.add(columns.get(i).sql(columnPrefix));
                }
            }
            if (key.isEffective()) {
                lines.add(String.format("    PRIMARY KEY (`%s`)", key.getName()));
            }
            StringBuffer sb = new StringBuffer();
            sb.append(") ENGINE=InnoDB");
            if (key.isEffective() && key.isIncreased()) {
                sb.append(" AUTO_INCREMENT=1");
            }
            sb.append(" DEFAULT CHARSET=utf8");
            if (Objects.isNull(comment)) {
                sb.append(" COMMENT='';");
            } else {
                sb.append(" COMMENT='").append(comment).append("';");
            }
            lines.add(sb.toString());
            return String.join("\n", lines);
        }

        public <T, R> Table key(SerializableFunction<T, R> fn) {
            String newKeyName = getSqlName(getFieldName(fn));
            Column find = null;
            for (Column column : columns) {
                if (column.getName().equals(newKeyName)) {
                    find = column;
                }
            }
            if (Objects.nonNull(find)) {
                Column copy = find.copy();
                columns.remove(find);
                columns.addFirst(copy);
            }
            return this;
        }

        public <T, R> Table column(SerializableFunction<T, R> fn, String type, String comment, String defaultValue) {
            String fieldName = getFieldName(fn);
            String columnName = getSqlName(fieldName);
            for (Column column : columns) {
                if (column.getName().equals(columnName)) {
                    if (Objects.nonNull(type)) {
                        column.setType(type);
                    }
                    if (Objects.nonNull(comment)) {
                        column.setComment(comment);
                    }
                    if (Objects.nonNull(defaultValue)) {
                        column.setDefaultValue(defaultValue);
                    }
                    break;
                }
            }
            return this;
        }

        public <T, R> Table invalid(SerializableFunction<T, R>... fns) {
            if (Objects.nonNull(fns)) {
                for (SerializableFunction<T, R> fn : fns) {
                    String fieldName = getFieldName(fn);
                    String columnName = getSqlName(fieldName);
                    for (Column column : columns) {
                        if (column.getName().equals(columnName)) {
                            column.setEffective(false);
                            break;
                        }
                    }
                }
            }
            return this;
        }

        private static <T> String getFieldName(SerializableFunction<T, ?> func) {
            // 通过获取对象方法,判断是否存在该方法
            String getter = null;
            try {
                Method method = func.getClass().getDeclaredMethod("writeReplace");
                method.setAccessible(Boolean.TRUE);
                // 利用jdk的SerializedLambda 解析方法引用
                SerializedLambda serializedLambda = (SerializedLambda) method.invoke(func);
                getter = serializedLambda.getImplMethodName();
            } catch (Exception e) {
                e.printStackTrace();
            }
            if (Objects.isNull(getter)) {
                return null;
            }
            String fieldName;
            if (getter.startsWith("get")) {
                fieldName = getter.substring(3);
            } else if (getter.startsWith("is")) {
                fieldName = getter.substring(2);
            } else {
                fieldName = getter;
            }
            return Character.toLowerCase(fieldName.charAt(0)) + fieldName.substring(1);
        }

    }

    /**
     * SQL表字段
     */
    static class Column {

        private String name;
        private String type;
        private String comment;
        private String defaultValue;
        private boolean effective = true;
        private boolean increased;

        public String getName() {
            return name;
        }

        public void setName(String name) {
            this.name = name;
        }

        public String getType() {
            return type;
        }

        public void setType(String type) {
            this.type = type;
        }

        public String getComment() {
            return comment;
        }

        public void setComment(String comment) {
            this.comment = comment;
        }

        public String getDefaultValue() {
            return defaultValue;
        }

        public void setDefaultValue(String defaultValue) {
            this.defaultValue = defaultValue;
        }

        public boolean isEffective() {
            return effective;
        }

        public void setEffective(boolean effective) {
            this.effective = effective;
        }

        public boolean isIncreased() {
            return increased;
        }

        public void setIncreased(boolean increased) {
            this.increased = increased;
        }

        public Column copy() {
            Column copy = new Column();
            copy.setName(this.name);
            copy.setType(this.getType());
            copy.setComment(this.comment);
            copy.setDefaultValue(this.defaultValue);
            copy.setEffective(this.effective);
            copy.setIncreased(this.increased);
            return copy;
        }

        public String sql(String columnPrefix) {
            StringBuffer sb = new StringBuffer();
            // 字段名称
            sb.append("`");
            if (Objects.nonNull(columnPrefix)) {
                sb.append(columnPrefix);
            }
            sb.append(name).append("`");
            // 字段类型
            sb.append(" ").append(type);
            // 是否自增
            if (increased) {
                sb.append(" NOT NULL AUTO_INCREMENT");
            } else {
                // 默认值
                if (Objects.isNull(defaultValue)) {
                    if (type.startsWith("tinyint") || type.startsWith("int") || type.startsWith("bigint")) {
                        sb.append(" DEFAULT 0");
                    } else if (type.startsWith("varchar")) {
                        sb.append(" DEFAULT ''");
                    } else {
                        sb.append(" DEFAULT NULL");
                    }
                } else {
                    sb.append(" DEFAULT '").append(defaultValue).append("'");
                }
            }
            // 备注
            if (Objects.isNull(comment)) {
                sb.append(" COMMENT ''");
            } else {
                sb.append(" COMMENT '").append(comment).append("'");
            }
            return "    " + sb + ",";
        }
    }

    @FunctionalInterface
    public interface SerializableFunction<T, R> extends Function<T, R>, Serializable {
    }

}

浏览 (260)
点赞
收藏
评论