From fa3f3204de3831c6734a78ddaeb85b52056f8ac1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A9=AC=E6=99=93=E5=85=89?= Date: Mon, 15 Apr 2019 20:40:52 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BD=BF=E7=94=A8mycat?= =?UTF-8?q?=E8=87=AA=E5=A2=9E=E5=BA=8F=E5=88=97=E6=8F=92=E5=85=A5=E5=8C=85?= =?UTF-8?q?=E5=90=ABon=20duplicate=20key=20update=E5=AD=90=E5=8F=A5?= =?UTF-8?q?=E6=97=B6bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../java/io/mycat/route/util/RouterUtil.java | 132 ++++++++++-------- .../io/mycat/route/util/RouterUtilTest.java | 21 ++- 2 files changed, 91 insertions(+), 62 deletions(-) diff --git a/src/main/java/io/mycat/route/util/RouterUtil.java b/src/main/java/io/mycat/route/util/RouterUtil.java index 724ff37fc..6f98e5882 100644 --- a/src/main/java/io/mycat/route/util/RouterUtil.java +++ b/src/main/java/io/mycat/route/util/RouterUtil.java @@ -642,7 +642,12 @@ public static boolean processInsert(ServerConnection sc,SchemaConfig schema, if(valuesIndex + "VALUES".length() <= firstLeftBracketIndex) { throw new SQLSyntaxErrorException("insert must provide ColumnList"); } - List> vauleList = parseSqlValue(origSQL , valuesIndex); + Object[] vauleArrayAndSuffixStr = parseSqlValueArrayAndSuffixStr(origSQL , valuesIndex); + List> vauleArray = (List>) vauleArrayAndSuffixStr[0]; + String suffixStr = null; + if (vauleArrayAndSuffixStr.length > 1) { + suffixStr = (String) vauleArrayAndSuffixStr[1]; + } //两种情况处理 1 有主键的 id ,但是值为null 进行改下 // 2 没有主键的 需要插入 进行改写 @@ -653,7 +658,7 @@ public static boolean processInsert(ServerConnection sc,SchemaConfig schema, if(pkStart == -1){ processedInsert = true; - handleBatchInsert(sc, schema, sqlType,origSQL, valuesIndex, tableName, primaryKey, vauleList); + handleBatchInsert(sc, schema, sqlType,origSQL, valuesIndex, tableName, primaryKey, vauleArray, suffixStr); } else { //判断 主键id的值是否为null if(pkStart != -1) { @@ -666,7 +671,7 @@ public static boolean processInsert(ServerConnection sc,SchemaConfig schema, pkIndex ++; } } - processedInsert = handleBatchInsertWithPK(sc, schema, sqlType,origSQL, valuesIndex, tableName, primaryKey, vauleList , pkIndex); + processedInsert = handleBatchInsertWithPK(sc, schema, sqlType,origSQL, valuesIndex, tableName, primaryKey, vauleArray, suffixStr, pkIndex); } } return processedInsert; @@ -674,36 +679,38 @@ public static boolean processInsert(ServerConnection sc,SchemaConfig schema, private static boolean handleBatchInsertWithPK(ServerConnection sc, SchemaConfig schema, int sqlType, String origSQL, int valuesIndex, String tableName, String primaryKey, List> vauleList, - int pkIndex) { - boolean processedInsert = false; + String suffixStr, int pkIndex) { + boolean processedInsert = false; // final String pk = "\\("+primaryKey+","; - final String mycatSeqPrefix = "next value for MYCATSEQ_"+tableName.toUpperCase() ; + final String mycatSeqPrefix = "next value for MYCATSEQ_"+tableName.toUpperCase() ; /*"VALUES".length() ==6 */ - String prefix = origSQL.substring(0, valuesIndex + 6); + String prefix = origSQL.substring(0, valuesIndex + 6); // - StringBuilder sb = new StringBuilder(""); - for(List list : vauleList) { - sb.append("("); - String pkValue = list.get(pkIndex).trim().toLowerCase(); - //null值替换为 next value for MYCATSEQ_tableName - if("null".equals(pkValue.trim())) { - list.set(pkIndex, mycatSeqPrefix); - processedInsert = true; - } - for(String val : list) { - sb.append(val).append(","); - } - sb.setCharAt(sb.length() - 1, ')'); - sb.append(","); - } - sb.setCharAt(sb.length() - 1, ' ');; - if(processedInsert) { - processSQL(sc, schema,prefix+sb.toString(), sqlType); - - } - return processedInsert; + StringBuilder sb = new StringBuilder(""); + for(List list : vauleList) { + sb.append("("); + String pkValue = list.get(pkIndex).trim().toLowerCase(); + //null值替换为 next value for MYCATSEQ_tableName + if("null".equals(pkValue.trim())) { + list.set(pkIndex, mycatSeqPrefix); + processedInsert = true; + } + for(String val : list) { + sb.append(val).append(","); + } + sb.setCharAt(sb.length() - 1, ')'); + sb.append(","); + } + sb.setCharAt(sb.length() - 1, ' '); + if (suffixStr != null) { + sb.append(suffixStr); + } + if(processedInsert) { + processSQL(sc, schema,prefix+sb.toString(), sqlType); + } + return processedInsert; } public static List handleBatchInsert(String origSQL, int valuesIndex) { @@ -765,40 +772,49 @@ public static List handleBatchInsert(String origSQL, int valuesIndex) { * [")", "\"\')"], * [ 1, null] * 值结果的解析 - */ - public static List> parseSqlValue(String origSQL,int valuesIndex ) { + */ + public static Object[] parseSqlValueArrayAndSuffixStr(String origSQL, int valuesIndex) { List> valueArray = new ArrayList<>(); - String valueStr = origSQL.substring(valuesIndex + 6);// 6 values 长度为6 - String preStr = origSQL.substring(0, valuesIndex );// 6 values 长度为6 + String valuesAndSuffixStr = origSQL.substring(valuesIndex + 6);// 6 values 长度为6 int pos = 0 ; int flag = 4; - int len = valueStr.length(); + int len = valuesAndSuffixStr.length(); StringBuilder currentValue = new StringBuilder(); // int colNum = 2; // char c ; List curList = new ArrayList<>(); int parenCount = 0; for( ;pos < len; pos ++) { - c = valueStr.charAt(pos); - if(flag == 1 || flag == 2) { + c = valuesAndSuffixStr.charAt(pos); + if (flag == 1 || flag == 2) { currentValue.append(c); - if(c == '\\') { - char nextCode = valueStr.charAt(pos + 1); - if(nextCode == '\'' || nextCode == '\"') { + if (c == '\\') { + char nextCode = valuesAndSuffixStr.charAt(pos + 1); + if (nextCode == '\'' || nextCode == '\"') { currentValue.append(nextCode); pos++; continue; } } - if(c == '\"' && flag == 1) { + if (c == '\"' && flag == 1) { flag = 0; continue; } - if(c == '\'' && flag == 2) { + if (c == '\'' && flag == 2) { flag = 0; continue; } - } else if(c == '\"'){ + } else if (flag == 5) { + currentValue.append(c); + if (c == '(') { + parenCount++; + } else if (c == ')') { + parenCount--; + } + if (parenCount == 0) { + flag = 0; + } + } else if (c == '\"'){ currentValue.append(c); flag = 1; } else if (c == '\'') { @@ -810,26 +826,20 @@ public static List> parseSqlValue(String origSQL,int valuesIndex ) flag = 0; } else { currentValue.append(c); - flag = 6; + flag = 5; parenCount++; } } else if (flag == 4) { + if (c == 'o') { + String suffixStr = valuesAndSuffixStr.substring(pos); + return new Object[]{valueArray, suffixStr}; + } continue; - } else if (flag == 6) { - currentValue.append(c); - if (c == '(') { - parenCount++; - } else if (c == ')') { - parenCount--; - } - if (parenCount == 0) { - flag = 0; - } - } else if(c == ',') { + } else if (c == ',') { // System.out.println(currentValue); curList.add(currentValue.toString()); currentValue.delete(0, currentValue.length()); - } else if(c == ')'){ + } else if (c == ')'){ flag = 4; // System.out.println(currentValue); curList.add(currentValue.toString()); @@ -839,17 +849,16 @@ public static List> parseSqlValue(String origSQL,int valuesIndex ) currentValue.append(c); } } - return valueArray; - } - - /** + return new Object[]{valueArray}; + } + /** * 对于主键不在插入语句的fields中的SQL,需要改写。比如hotnews主键为id,插入语句为: * insert into hotnews(title) values('aaa'); * 需要改写成: * insert into hotnews(id, title) values(next value for MYCATSEQ_hotnews,'aaa'); */ public static void handleBatchInsert(ServerConnection sc, SchemaConfig schema, - int sqlType,String origSQL, int valuesIndex,String tableName, String primaryKey , List> vauleList) { + int sqlType,String origSQL, int valuesIndex,String tableName, String primaryKey , List> vauleList, String suffixStr) { final String pk = "\\("+primaryKey+","; final String mycatSeqPrefix = "(next value for MYCATSEQ_"+tableName.toUpperCase()+""; @@ -867,7 +876,10 @@ public static void handleBatchInsert(ServerConnection sc, SchemaConfig schema, } sb.append("),"); } - sb.setCharAt(sb.length() - 1, ' ');; + sb.setCharAt(sb.length() - 1, ' '); + if (suffixStr != null) { + sb.append(suffixStr); + } processSQL(sc, schema,prefix+sb.toString(), sqlType); } // /** diff --git a/src/test/java/io/mycat/route/util/RouterUtilTest.java b/src/test/java/io/mycat/route/util/RouterUtilTest.java index 6c9357aa1..64ca03db1 100644 --- a/src/test/java/io/mycat/route/util/RouterUtilTest.java +++ b/src/test/java/io/mycat/route/util/RouterUtilTest.java @@ -6,6 +6,8 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; @@ -31,8 +33,23 @@ public void testBatchInsert() { Assert.assertTrue(values.get(2).equals("insert into hotnews(title,name) values('\\\"',\"\\'\")")); Assert.assertTrue(values.get(3).equals("insert into hotnews(title,name) values(\")\",\"\\\"\\')\")")); } - - + + @Test + public void testParseSqlValueArrayAndSuffixStr() { + String sql = "insert into hotnews(title,name) values('test1',\"name\"),('(test)',\"(test)\"),('\\\"',\"\\'\"),(\")\",\"\\\"\\')\"),(left(upper('test'), 2),\"left(upper('test'), 2)\") on duplicate key update name = values(name)"; + Object[] valueArrayAndSuffixStr = RouterUtil.parseSqlValueArrayAndSuffixStr(sql, sql.toUpperCase().indexOf("VALUES")); + Assert.assertTrue(valueArrayAndSuffixStr.length == 2); + List> valueArray = (List>) valueArrayAndSuffixStr[0]; + String suffixStr = (String) valueArrayAndSuffixStr[1]; + Assert.assertTrue(valueArray.size() == 5); + Assert.assertTrue(valueArray.get(0).equals(new ArrayList(Arrays.asList("'test1'", "\"name\"")))); + Assert.assertTrue(valueArray.get(1).equals(new ArrayList(Arrays.asList("'(test)'", "\"(test)\"")))); + Assert.assertTrue(valueArray.get(2).equals(new ArrayList(Arrays.asList("'\\\"'", "\"\\'\"")))); + Assert.assertTrue(valueArray.get(3).equals(new ArrayList(Arrays.asList("\")\"", "\"\\\"\\')\"")))); + Assert.assertTrue(valueArray.get(4).equals(new ArrayList(Arrays.asList("left(upper('test'), 2)", "\"left(upper('test'), 2)\"")))); + Assert.assertTrue(suffixStr.equals("on duplicate key update name = values(name)")); + } + @Test public void testRemoveSchema() { String sql = "update test set name='abcdtestx.aa' where id=1 and testx=123";