Skip to content

Commit

Permalink
修复使用mycat自增序列插入包含on duplicate key update子句时bug
Browse files Browse the repository at this point in the history
  • Loading branch information
maxiaoguang64 committed Apr 15, 2019
1 parent 1d4d4e5 commit fa3f320
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 62 deletions.
132 changes: 72 additions & 60 deletions src/main/java/io/mycat/route/util/RouterUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,12 @@ public static boolean processInsert(ServerConnection sc,SchemaConfig schema,
if(valuesIndex + "VALUES".length() <= firstLeftBracketIndex) {
throw new SQLSyntaxErrorException("insert must provide ColumnList");
}
List<List<String>> vauleList = parseSqlValue(origSQL , valuesIndex);
Object[] vauleArrayAndSuffixStr = parseSqlValueArrayAndSuffixStr(origSQL , valuesIndex);
List<List<String>> vauleArray = (List<List<String>>) vauleArrayAndSuffixStr[0];
String suffixStr = null;
if (vauleArrayAndSuffixStr.length > 1) {
suffixStr = (String) vauleArrayAndSuffixStr[1];
}
//两种情况处理 1 有主键的 id ,但是值为null 进行改下
// 2 没有主键的 需要插入 进行改写

Expand All @@ -653,7 +658,7 @@ public static boolean processInsert(ServerConnection sc,SchemaConfig schema,

if(pkStart == -1){
processedInsert = true;
handleBatchInsert(sc, schema, sqlType,origSQL, valuesIndex, tableName, primaryKey, vauleList);
handleBatchInsert(sc, schema, sqlType,origSQL, valuesIndex, tableName, primaryKey, vauleArray, suffixStr);
} else {
//判断 主键id的值是否为null
if(pkStart != -1) {
Expand All @@ -666,44 +671,46 @@ public static boolean processInsert(ServerConnection sc,SchemaConfig schema,
pkIndex ++;
}
}
processedInsert = handleBatchInsertWithPK(sc, schema, sqlType,origSQL, valuesIndex, tableName, primaryKey, vauleList , pkIndex);
processedInsert = handleBatchInsertWithPK(sc, schema, sqlType,origSQL, valuesIndex, tableName, primaryKey, vauleArray, suffixStr, pkIndex);
}
}
return processedInsert;
}

private static boolean handleBatchInsertWithPK(ServerConnection sc, SchemaConfig schema, int sqlType,
String origSQL, int valuesIndex, String tableName, String primaryKey, List<List<String>> vauleList,
int pkIndex) {
boolean processedInsert = false;
String suffixStr, int pkIndex) {
boolean processedInsert = false;
// final String pk = "\\("+primaryKey+",";
final String mycatSeqPrefix = "next value for MYCATSEQ_"+tableName.toUpperCase() ;
final String mycatSeqPrefix = "next value for MYCATSEQ_"+tableName.toUpperCase() ;

/*"VALUES".length() ==6 */
String prefix = origSQL.substring(0, valuesIndex + 6);
String prefix = origSQL.substring(0, valuesIndex + 6);
//

StringBuilder sb = new StringBuilder("");
for(List<String> list : vauleList) {
sb.append("(");
String pkValue = list.get(pkIndex).trim().toLowerCase();
//null值替换为 next value for MYCATSEQ_tableName
if("null".equals(pkValue.trim())) {
list.set(pkIndex, mycatSeqPrefix);
processedInsert = true;
}
for(String val : list) {
sb.append(val).append(",");
}
sb.setCharAt(sb.length() - 1, ')');
sb.append(",");
}
sb.setCharAt(sb.length() - 1, ' ');;
if(processedInsert) {
processSQL(sc, schema,prefix+sb.toString(), sqlType);

}
return processedInsert;
StringBuilder sb = new StringBuilder("");
for(List<String> list : vauleList) {
sb.append("(");
String pkValue = list.get(pkIndex).trim().toLowerCase();
//null值替换为 next value for MYCATSEQ_tableName
if("null".equals(pkValue.trim())) {
list.set(pkIndex, mycatSeqPrefix);
processedInsert = true;
}
for(String val : list) {
sb.append(val).append(",");
}
sb.setCharAt(sb.length() - 1, ')');
sb.append(",");
}
sb.setCharAt(sb.length() - 1, ' ');
if (suffixStr != null) {
sb.append(suffixStr);
}
if(processedInsert) {
processSQL(sc, schema,prefix+sb.toString(), sqlType);
}
return processedInsert;
}

public static List<String> handleBatchInsert(String origSQL, int valuesIndex) {
Expand Down Expand Up @@ -765,40 +772,49 @@ public static List<String> handleBatchInsert(String origSQL, int valuesIndex) {
* [")", "\"\')"],
* [ 1, null]
* 值结果的解析
*/
public static List<List<String>> parseSqlValue(String origSQL,int valuesIndex ) {
*/
public static Object[] parseSqlValueArrayAndSuffixStr(String origSQL, int valuesIndex) {
List<List<String>> valueArray = new ArrayList<>();
String valueStr = origSQL.substring(valuesIndex + 6);// 6 values 长度为6
String preStr = origSQL.substring(0, valuesIndex );// 6 values 长度为6
String valuesAndSuffixStr = origSQL.substring(valuesIndex + 6);// 6 values 长度为6
int pos = 0 ;
int flag = 4;
int len = valueStr.length();
int len = valuesAndSuffixStr.length();
StringBuilder currentValue = new StringBuilder();
// int colNum = 2; //
char c ;
List<String> curList = new ArrayList<>();
int parenCount = 0;
for( ;pos < len; pos ++) {
c = valueStr.charAt(pos);
if(flag == 1 || flag == 2) {
c = valuesAndSuffixStr.charAt(pos);
if (flag == 1 || flag == 2) {
currentValue.append(c);
if(c == '\\') {
char nextCode = valueStr.charAt(pos + 1);
if(nextCode == '\'' || nextCode == '\"') {
if (c == '\\') {
char nextCode = valuesAndSuffixStr.charAt(pos + 1);
if (nextCode == '\'' || nextCode == '\"') {
currentValue.append(nextCode);
pos++;
continue;
}
}
if(c == '\"' && flag == 1) {
if (c == '\"' && flag == 1) {
flag = 0;
continue;
}
if(c == '\'' && flag == 2) {
if (c == '\'' && flag == 2) {
flag = 0;
continue;
}
} else if(c == '\"'){
} else if (flag == 5) {
currentValue.append(c);
if (c == '(') {
parenCount++;
} else if (c == ')') {
parenCount--;
}
if (parenCount == 0) {
flag = 0;
}
} else if (c == '\"'){
currentValue.append(c);
flag = 1;
} else if (c == '\'') {
Expand All @@ -810,26 +826,20 @@ public static List<List<String>> parseSqlValue(String origSQL,int valuesIndex )
flag = 0;
} else {
currentValue.append(c);
flag = 6;
flag = 5;
parenCount++;
}
} else if (flag == 4) {
if (c == 'o') {
String suffixStr = valuesAndSuffixStr.substring(pos);
return new Object[]{valueArray, suffixStr};
}
continue;
} else if (flag == 6) {
currentValue.append(c);
if (c == '(') {
parenCount++;
} else if (c == ')') {
parenCount--;
}
if (parenCount == 0) {
flag = 0;
}
} else if(c == ',') {
} else if (c == ',') {
// System.out.println(currentValue);
curList.add(currentValue.toString());
currentValue.delete(0, currentValue.length());
} else if(c == ')'){
} else if (c == ')'){
flag = 4;
// System.out.println(currentValue);
curList.add(currentValue.toString());
Expand All @@ -839,17 +849,16 @@ public static List<List<String>> parseSqlValue(String origSQL,int valuesIndex )
currentValue.append(c);
}
}
return valueArray;
}

/**
return new Object[]{valueArray};
}
/**
* 对于主键不在插入语句的fields中的SQL,需要改写。比如hotnews主键为id,插入语句为:
* insert into hotnews(title) values('aaa');
* 需要改写成:
* insert into hotnews(id, title) values(next value for MYCATSEQ_hotnews,'aaa');
*/
public static void handleBatchInsert(ServerConnection sc, SchemaConfig schema,
int sqlType,String origSQL, int valuesIndex,String tableName, String primaryKey , List<List<String>> vauleList) {
int sqlType,String origSQL, int valuesIndex,String tableName, String primaryKey , List<List<String>> vauleList, String suffixStr) {

final String pk = "\\("+primaryKey+",";
final String mycatSeqPrefix = "(next value for MYCATSEQ_"+tableName.toUpperCase()+"";
Expand All @@ -867,7 +876,10 @@ public static void handleBatchInsert(ServerConnection sc, SchemaConfig schema,
}
sb.append("),");
}
sb.setCharAt(sb.length() - 1, ' ');;
sb.setCharAt(sb.length() - 1, ' ');
if (suffixStr != null) {
sb.append(suffixStr);
}
processSQL(sc, schema,prefix+sb.toString(), sqlType);
}
// /**
Expand Down
21 changes: 19 additions & 2 deletions src/test/java/io/mycat/route/util/RouterUtilTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

Expand All @@ -31,8 +33,23 @@ public void testBatchInsert() {
Assert.assertTrue(values.get(2).equals("insert into hotnews(title,name) values('\\\"',\"\\'\")"));
Assert.assertTrue(values.get(3).equals("insert into hotnews(title,name) values(\")\",\"\\\"\\')\")"));
}



@Test
public void testParseSqlValueArrayAndSuffixStr() {
String sql = "insert into hotnews(title,name) values('test1',\"name\"),('(test)',\"(test)\"),('\\\"',\"\\'\"),(\")\",\"\\\"\\')\"),(left(upper('test'), 2),\"left(upper('test'), 2)\") on duplicate key update name = values(name)";
Object[] valueArrayAndSuffixStr = RouterUtil.parseSqlValueArrayAndSuffixStr(sql, sql.toUpperCase().indexOf("VALUES"));
Assert.assertTrue(valueArrayAndSuffixStr.length == 2);
List<List<String>> valueArray = (List<List<String>>) valueArrayAndSuffixStr[0];
String suffixStr = (String) valueArrayAndSuffixStr[1];
Assert.assertTrue(valueArray.size() == 5);
Assert.assertTrue(valueArray.get(0).equals(new ArrayList<String>(Arrays.asList("'test1'", "\"name\""))));
Assert.assertTrue(valueArray.get(1).equals(new ArrayList<String>(Arrays.asList("'(test)'", "\"(test)\""))));
Assert.assertTrue(valueArray.get(2).equals(new ArrayList<String>(Arrays.asList("'\\\"'", "\"\\'\""))));
Assert.assertTrue(valueArray.get(3).equals(new ArrayList<String>(Arrays.asList("\")\"", "\"\\\"\\')\""))));
Assert.assertTrue(valueArray.get(4).equals(new ArrayList<String>(Arrays.asList("left(upper('test'), 2)", "\"left(upper('test'), 2)\""))));
Assert.assertTrue(suffixStr.equals("on duplicate key update name = values(name)"));
}

@Test
public void testRemoveSchema() {
String sql = "update test set name='abcdtestx.aa' where id=1 and testx=123";
Expand Down

0 comments on commit fa3f320

Please sign in to comment.