diff --git a/src/main/java/io/mycat/backend/mysql/nio/MySQLConnection.java b/src/main/java/io/mycat/backend/mysql/nio/MySQLConnection.java index d83a58c63..813523ed8 100644 --- a/src/main/java/io/mycat/backend/mysql/nio/MySQLConnection.java +++ b/src/main/java/io/mycat/backend/mysql/nio/MySQLConnection.java @@ -1,709 +1,707 @@ -/* - * Copyright (c) 2013, OpenCloudDB/MyCAT and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software;Designed and Developed mainly by many Chinese - * opensource volunteers. you can redistribute it and/or modify it under the - * terms of the GNU General Public License version 2 only, as published by the - * Free Software Foundation. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Any questions about this component can be directed to it's project Web address - * https://code.google.com/p/opencloudb/. - * - */ -package io.mycat.backend.mysql.nio; - -import io.mycat.backend.mysql.xa.TxState; -import org.slf4j.Logger; import org.slf4j.LoggerFactory; - -import io.mycat.MycatServer; -import io.mycat.backend.mysql.CharsetUtil; -import io.mycat.backend.mysql.SecurityUtil; -import io.mycat.backend.mysql.nio.handler.ResponseHandler; -import io.mycat.config.Capabilities; -import io.mycat.config.Isolations; -import io.mycat.net.BackendAIOConnection; -import io.mycat.net.mysql.*; -import io.mycat.route.RouteResultsetNode; -import io.mycat.server.ServerConnection; -import io.mycat.server.parser.ServerParse; -import io.mycat.util.TimeUtil; -import io.mycat.util.exception.UnknownTxIsolationException; - -import java.io.UnsupportedEncodingException; -import java.nio.channels.NetworkChannel; -import java.security.NoSuchAlgorithmException; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; - -/** - * @author mycat - */ -public class MySQLConnection extends BackendAIOConnection { - private static final Logger LOGGER = LoggerFactory - .getLogger(MySQLConnection.class); - private static final long CLIENT_FLAGS = initClientFlags(); - private volatile long lastTime; - private volatile String schema = null; - private volatile String oldSchema; - private volatile boolean borrowed = false; - private volatile boolean modifiedSQLExecuted = false; - private volatile int batchCmdCount = 0; - - private static long initClientFlags() { - int flag = 0; - flag |= Capabilities.CLIENT_LONG_PASSWORD; - flag |= Capabilities.CLIENT_FOUND_ROWS; - flag |= Capabilities.CLIENT_LONG_FLAG; - flag |= Capabilities.CLIENT_CONNECT_WITH_DB; - // flag |= Capabilities.CLIENT_NO_SCHEMA; - boolean usingCompress=MycatServer.getInstance().getConfig().getSystem().getUseCompression()==1 ; - if(usingCompress) - { - flag |= Capabilities.CLIENT_COMPRESS; - } - flag |= Capabilities.CLIENT_ODBC; - flag |= Capabilities.CLIENT_LOCAL_FILES; - flag |= Capabilities.CLIENT_IGNORE_SPACE; - flag |= Capabilities.CLIENT_PROTOCOL_41; - flag |= Capabilities.CLIENT_INTERACTIVE; - // flag |= Capabilities.CLIENT_SSL; - flag |= Capabilities.CLIENT_IGNORE_SIGPIPE; - flag |= Capabilities.CLIENT_TRANSACTIONS; - // flag |= Capabilities.CLIENT_RESERVED; - flag |= Capabilities.CLIENT_SECURE_CONNECTION; - // client extension - flag |= Capabilities.CLIENT_MULTI_STATEMENTS; - flag |= Capabilities.CLIENT_MULTI_RESULTS; - return flag; - } - - private static final CommandPacket _READ_UNCOMMITTED = new CommandPacket(); - private static final CommandPacket _READ_COMMITTED = new CommandPacket(); - private static final CommandPacket _REPEATED_READ = new CommandPacket(); - private static final CommandPacket _SERIALIZABLE = new CommandPacket(); - private static final CommandPacket _AUTOCOMMIT_ON = new CommandPacket(); - private static final CommandPacket _AUTOCOMMIT_OFF = new CommandPacket(); - private static final CommandPacket _COMMIT = new CommandPacket(); - private static final CommandPacket _ROLLBACK = new CommandPacket(); - static { - _READ_UNCOMMITTED.packetId = 0; - _READ_UNCOMMITTED.command = MySQLPacket.COM_QUERY; - _READ_UNCOMMITTED.arg = "SET SESSION TRANSACTION ISOLATION LEVEL READ UNCOMMITTED" - .getBytes(); - _READ_COMMITTED.packetId = 0; - _READ_COMMITTED.command = MySQLPacket.COM_QUERY; - _READ_COMMITTED.arg = "SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED" - .getBytes(); - _REPEATED_READ.packetId = 0; - _REPEATED_READ.command = MySQLPacket.COM_QUERY; - _REPEATED_READ.arg = "SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ" - .getBytes(); - _SERIALIZABLE.packetId = 0; - _SERIALIZABLE.command = MySQLPacket.COM_QUERY; - _SERIALIZABLE.arg = "SET SESSION TRANSACTION ISOLATION LEVEL SERIALIZABLE" - .getBytes(); - _AUTOCOMMIT_ON.packetId = 0; - _AUTOCOMMIT_ON.command = MySQLPacket.COM_QUERY; - _AUTOCOMMIT_ON.arg = "SET autocommit=1".getBytes(); - _AUTOCOMMIT_OFF.packetId = 0; - _AUTOCOMMIT_OFF.command = MySQLPacket.COM_QUERY; - _AUTOCOMMIT_OFF.arg = "SET autocommit=0".getBytes(); - _COMMIT.packetId = 0; - _COMMIT.command = MySQLPacket.COM_QUERY; - _COMMIT.arg = "commit".getBytes(); - _ROLLBACK.packetId = 0; - _ROLLBACK.command = MySQLPacket.COM_QUERY; - _ROLLBACK.arg = "rollback".getBytes(); - } - - private MySQLDataSource pool; - private boolean fromSlaveDB; - private long threadId; - private HandshakePacket handshake; - private volatile int txIsolation; - private volatile boolean autocommit; - private long clientFlags; - private boolean isAuthenticated; - private String user; - private String password; - private Object attachment; - private volatile ResponseHandler respHandler; - - private final AtomicBoolean isQuit; - private volatile StatusSync statusSync; - private volatile boolean metaDataSyned = true; - private volatile int xaStatus = 0; - - public MySQLConnection(NetworkChannel channel, boolean fromSlaveDB) { - super(channel); - this.clientFlags = CLIENT_FLAGS; - this.lastTime = TimeUtil.currentTimeMillis(); - this.isQuit = new AtomicBoolean(false); - this.autocommit = true; - this.fromSlaveDB = fromSlaveDB; - // 设为默认值,免得每个初始化好的连接都要去同步一下 - this.txIsolation = MycatServer.getInstance().getConfig().getSystem().getTxIsolation(); - } - - public int getXaStatus() { - return xaStatus; - } - - public void setXaStatus(int xaStatus) { - this.xaStatus = xaStatus; - } - - public void onConnectFailed(Throwable t) { - if (handler instanceof MySQLConnectionHandler) { - MySQLConnectionHandler theHandler = (MySQLConnectionHandler) handler; - theHandler.connectionError(t); - } else { - ((MySQLConnectionAuthenticator) handler).connectionError(this, t); - } - } - - public String getSchema() { - return this.schema; - } - - public void setSchema(String newSchema) { - String curSchema = schema; - if (curSchema == null) { - this.schema = newSchema; - this.oldSchema = newSchema; - } else { - this.oldSchema = curSchema; - this.schema = newSchema; - } - } - - public MySQLDataSource getPool() { - return pool; - } - - public void setPool(MySQLDataSource pool) { - this.pool = pool; - } - - public String getUser() { - return user; - } - - public void setUser(String user) { - this.user = user; - } - - public void setPassword(String password) { - this.password = password; - } - - public HandshakePacket getHandshake() { - return handshake; - } - - public void setHandshake(HandshakePacket handshake) { - this.handshake = handshake; - } - - public long getThreadId() { - return threadId; - } - - public void setThreadId(long threadId) { - this.threadId = threadId; - } - - public boolean isAuthenticated() { - return isAuthenticated; - } - - public void setAuthenticated(boolean isAuthenticated) { - this.isAuthenticated = isAuthenticated; - } - - public String getPassword() { - return password; - } - - public void authenticate() { - AuthPacket packet = new AuthPacket(); - packet.packetId = 1; - packet.clientFlags = clientFlags; - packet.maxPacketSize = maxPacketSize; - packet.charsetIndex = this.charsetIndex; - packet.user = user; - try { - packet.password = passwd(password, handshake); - } catch (NoSuchAlgorithmException e) { - throw new RuntimeException(e.getMessage()); - } - packet.database = schema; - packet.write(this); - } - - public boolean isAutocommit() { - return autocommit; - } - - public Object getAttachment() { - return attachment; - } - - public void setAttachment(Object attachment) { - this.attachment = attachment; - } - - public boolean isClosedOrQuit() { - return isClosed() || isQuit.get(); - } - - protected void sendQueryCmd(String query) { - CommandPacket packet = new CommandPacket(); - packet.packetId = 0; - packet.command = MySQLPacket.COM_QUERY; - try { - packet.arg = query.getBytes(charset); - } catch (UnsupportedEncodingException e) { - throw new RuntimeException(e); - } - lastTime = TimeUtil.currentTimeMillis(); - packet.write(this); - } - - private static void getCharsetCommand(StringBuilder sb, int clientCharIndex) { - sb.append("SET names ").append(CharsetUtil.getCharset(clientCharIndex)) - .append(";"); - } - - private static void getTxIsolationCommand(StringBuilder sb, int txIsolation) { - switch (txIsolation) { - case Isolations.READ_UNCOMMITTED: - sb.append("SET SESSION TRANSACTION ISOLATION LEVEL READ UNCOMMITTED;"); - return; - case Isolations.READ_COMMITTED: - sb.append("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED;"); - return; - case Isolations.REPEATED_READ: - sb.append("SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ;"); - return; - case Isolations.SERIALIZABLE: - sb.append("SET SESSION TRANSACTION ISOLATION LEVEL SERIALIZABLE;"); - return; - default: - throw new UnknownTxIsolationException("txIsolation:" + txIsolation); - } - } - - private void getAutocommitCommand(StringBuilder sb, boolean autoCommit) { - if (autoCommit) { - sb.append("SET autocommit=1;"); - } else { - sb.append("SET autocommit=0;"); - } - } - - private static class StatusSync { - private final String schema; - private final Integer charsetIndex; - private final Integer txtIsolation; - private final Boolean autocommit; - private final AtomicInteger synCmdCount; - private final boolean xaStarted; - - public StatusSync(boolean xaStarted, String schema, - Integer charsetIndex, Integer txtIsolation, Boolean autocommit, - int synCount) { - super(); - this.xaStarted = xaStarted; - this.schema = schema; - this.charsetIndex = charsetIndex; - this.txtIsolation = txtIsolation; - this.autocommit = autocommit; - this.synCmdCount = new AtomicInteger(synCount); - } - - public boolean synAndExecuted(MySQLConnection conn) { - int remains = synCmdCount.decrementAndGet(); - if (remains == 0) {// syn command finished - this.updateConnectionInfo(conn); - conn.metaDataSyned = true; - return false; - } else if (remains < 0) { - return true; - } - return false; - } - - private void updateConnectionInfo(MySQLConnection conn) - - { - if (schema != null) { - conn.schema = schema; - conn.oldSchema = conn.schema; - } - if (charsetIndex != null) { - conn.setCharset(CharsetUtil.getCharset(charsetIndex)); - } - if (txtIsolation != null) { - conn.txIsolation = txtIsolation; - } - if (autocommit != null) { - conn.autocommit = autocommit; - } - } - - } - - /** - * @return if synchronization finished and execute-sql has already been sent - * before - */ - public boolean syncAndExcute() { - StatusSync sync = this.statusSync; - if (sync == null) { - return true; - } else { - boolean executed = sync.synAndExecuted(this); - if (executed) { - statusSync = null; - } - return executed; - } - - } - - public void execute(RouteResultsetNode rrn, ServerConnection sc, - boolean autocommit) throws UnsupportedEncodingException { - if (!modifiedSQLExecuted && rrn.isModifySQL()) { - modifiedSQLExecuted = true; - } - String xaTXID = null; - if(sc.getSession2().getXaTXID()!=null){ - xaTXID = sc.getSession2().getXaTXID()+",'"+getSchema()+"'"; - } - synAndDoExecute(xaTXID, rrn, sc.getCharsetIndex(), sc.getTxIsolation(), - autocommit); - } - - private void synAndDoExecute(String xaTxID, RouteResultsetNode rrn, - int clientCharSetIndex, int clientTxIsoLation, - boolean clientAutoCommit) { - String xaCmd = null; - - boolean conAutoComit = this.autocommit; - String conSchema = this.schema; - boolean strictTxIsolation = MycatServer.getInstance().getConfig().getSystem().isStrictTxIsolation(); - boolean expectAutocommit = false; - // 如果在非自动提交情况下,如果需要严格保证事务级别,则需做下列判断 - if (strictTxIsolation) { - expectAutocommit = isFromSlaveDB() || clientAutoCommit; - } else { - // never executed modify sql,so auto commit - expectAutocommit = (!modifiedSQLExecuted || isFromSlaveDB() || clientAutoCommit); - } - expectAutocommit = true; - - if (expectAutocommit == false && xaTxID != null && xaStatus == TxState.TX_INITIALIZE_STATE) { - //clientTxIsoLation = Isolations.SERIALIZABLE; - xaCmd = "XA START " + xaTxID + ';'; - this.xaStatus = TxState.TX_STARTED_STATE; - } - int schemaSyn = conSchema.equals(oldSchema) ? 0 : 1; - int charsetSyn = 0; - if (this.charsetIndex != clientCharSetIndex) { - //need to syn the charset of connection. - //set current connection charset to client charset. - //otherwise while sending commend to server the charset will not coincidence. - setCharset(CharsetUtil.getCharset(clientCharSetIndex)); - charsetSyn = 1; - } - int txIsoLationSyn = (txIsolation == clientTxIsoLation) ? 0 : 1; - int autoCommitSyn = (conAutoComit == expectAutocommit) ? 0 : 1; - int synCount = schemaSyn + charsetSyn + txIsoLationSyn + autoCommitSyn + (xaCmd!=null?1:0); -// if (synCount == 0 && this.xaStatus != TxState.TX_STARTED_STATE) { - if (synCount == 0 ) { - // not need syn connection - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("not need syn connection :\n" + this+"\n to send query cmd:\n"+rrn.getStatement() - +"\n in pool\n" - +this.getPool().getConfig()); - } - sendQueryCmd(rrn.getStatement()); - return; - } - CommandPacket schemaCmd = null; - StringBuilder sb = new StringBuilder(); - if (schemaSyn == 1) { - schemaCmd = getChangeSchemaCommand(conSchema); - // getChangeSchemaCommand(sb, conSchema); - } - - if (charsetSyn == 1) { - getCharsetCommand(sb, clientCharSetIndex); - } - if (txIsoLationSyn == 1) { - getTxIsolationCommand(sb, clientTxIsoLation); - } - if (autoCommitSyn == 1) { - getAutocommitCommand(sb, expectAutocommit); - } - if (xaCmd != null) { - sb.append(xaCmd); - } - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("con need syn ,total syn cmd " + synCount - + " commands " + sb.toString() + "schema change:" - + (schemaCmd != null) + " con:" + this); - } - metaDataSyned = false; - statusSync = new StatusSync(xaCmd != null, conSchema, - clientCharSetIndex, clientTxIsoLation, expectAutocommit, - synCount); - // syn schema - if (schemaCmd != null) { - schemaCmd.write(this); - } - // and our query sql to multi command at last - sb.append(rrn.getStatement()+";"); - // syn and execute others - this.sendQueryCmd(sb.toString()); - // waiting syn result... - - } - - private static CommandPacket getChangeSchemaCommand(String schema) { - CommandPacket cmd = new CommandPacket(); - cmd.packetId = 0; - cmd.command = MySQLPacket.COM_INIT_DB; - cmd.arg = schema.getBytes(); - return cmd; - } - - /** - * by wuzh ,execute a query and ignore transaction settings for performance - * - * @param query - * @throws UnsupportedEncodingException - */ - public void query(String query) throws UnsupportedEncodingException { - RouteResultsetNode rrn = new RouteResultsetNode("default", - ServerParse.SELECT, query); - - synAndDoExecute(null, rrn, this.charsetIndex, this.txIsolation, true); - - } - /** - * by zwy ,execute a query with charsetIndex - * - * @param query - * @throws UnsupportedEncodingException - */ - @Override - public void query(String query, int charsetIndex) { - RouteResultsetNode rrn = new RouteResultsetNode("default", - ServerParse.SELECT, query); - - synAndDoExecute(null, rrn, charsetIndex, this.txIsolation, true); - - } - public long getLastTime() { - return lastTime; - } - - public void setLastTime(long lastTime) { - this.lastTime = lastTime; - } - - public void quit() { - if (isQuit.compareAndSet(false, true) && !isClosed()) { - if (isAuthenticated) { - write(writeToBuffer(QuitPacket.QUIT, allocate())); - write(allocate()); - } else { - close("normal"); - } - } - } - - @Override - public void close(String reason) { - if (!isClosed.get()) { - isQuit.set(true); - ResponseHandler tmpRespHandlers= respHandler; - setResponseHandler(null); - super.close(reason); - pool.connectionClosed(this); - if (tmpRespHandlers != null) { - tmpRespHandlers.connectionClose(this, reason); - } - if( this.handler instanceof MySQLConnectionAuthenticator) { - ((MySQLConnectionAuthenticator) this.handler).connectionError(this, new Throwable(reason)); - - } - } else { - //主要起一个清理资源的作用 - super.close(reason); - } - } - - public void commit() { - - _COMMIT.write(this); - - } - - public boolean batchCmdFinished() { - batchCmdCount--; - return (batchCmdCount == 0); - } - - public void execCmd(String cmd) { - this.sendQueryCmd(cmd); - } - - public void execBatchCmd(String[] batchCmds) { - // "XA END "+xaID+";"+"XA PREPARE "+xaID - this.batchCmdCount = batchCmds.length; - StringBuilder sb = new StringBuilder(); - for (String sql : batchCmds) { - sb.append(sql).append(';'); - } - this.sendQueryCmd(sb.toString()); - } - - public void rollback() { - _ROLLBACK.write(this); - } - - public void release() { - if (metaDataSyned == false) {// indicate connection not normalfinished - // ,and - // we can't know it's syn status ,so - // close - // it - LOGGER.warn("can't sure connection syn result,so close it " + this); - this.respHandler = null; - this.close("syn status unkown "); - return; - } - metaDataSyned = true; - attachment = null; - statusSync = null; - modifiedSQLExecuted = false; - xaStatus = TxState.TX_INITIALIZE_STATE; - setResponseHandler(null); - pool.releaseChannel(this); - } - - public boolean setResponseHandler(ResponseHandler queryHandler) { - if (handler instanceof MySQLConnectionHandler) { - ((MySQLConnectionHandler) handler).setResponseHandler(queryHandler); - respHandler = queryHandler; - return true; - } else if (queryHandler != null) { - LOGGER.warn("set not MySQLConnectionHandler " - + queryHandler.getClass().getCanonicalName()); - } - return false; - } - - /** - * 写队列为空,可以继续写数据 - */ - public void writeQueueAvailable() { - if (respHandler != null) { - respHandler.writeQueueAvailable(); - } - } - - /** - * 记录sql执行信息 - */ - public void recordSql(String host, String schema, String stmt) { - // final long now = TimeUtil.currentTimeMillis(); - // if (now > this.lastTime) { - // // long time = now - this.lastTime; - // // SQLRecorder sqlRecorder = this.pool.getSqlRecorder(); - // // if (sqlRecorder.check(time)) { - // // SQLRecord recorder = new SQLRecord(); - // // recorder.host = host; - // // recorder.schema = schema; - // // recorder.statement = stmt; - // // recorder.startTime = lastTime; - // // recorder.executeTime = time; - // // recorder.dataNode = pool.getName(); - // // recorder.dataNodeIndex = pool.getIndex(); - // // sqlRecorder.add(recorder); - // // } - // } - // this.lastTime = now; - } - - private static byte[] passwd(String pass, HandshakePacket hs) - throws NoSuchAlgorithmException { - if (pass == null || pass.length() == 0) { - return null; - } - byte[] passwd = pass.getBytes(); - int sl1 = hs.seed.length; - int sl2 = hs.restOfScrambleBuff.length; - byte[] seed = new byte[sl1 + sl2]; - System.arraycopy(hs.seed, 0, seed, 0, sl1); - System.arraycopy(hs.restOfScrambleBuff, 0, seed, sl1, sl2); - return SecurityUtil.scramble411(passwd, seed); - } - - @Override - public boolean isFromSlaveDB() { - return fromSlaveDB; - } - - @Override - public boolean isBorrowed() { - return borrowed; - } - - @Override - public void setBorrowed(boolean borrowed) { - this.lastTime = TimeUtil.currentTimeMillis(); - this.borrowed = borrowed; - } - - @Override - public String toString() { - return "MySQLConnection [id=" + id + ", lastTime=" + lastTime - + ", user=" + user - + ", schema=" + schema + ", old shema=" + oldSchema - + ", borrowed=" + borrowed + ", fromSlaveDB=" + fromSlaveDB - + ", threadId=" + threadId + ", charset=" + charset - + ", txIsolation=" + txIsolation + ", autocommit=" + autocommit - + ", attachment=" + attachment + ", respHandler=" + respHandler - + ", host=" + host + ", port=" + port + ", statusSync=" - + statusSync + ", writeQueue=" + this.getWriteQueue().size() - + ", modifiedSQLExecuted=" + modifiedSQLExecuted + "]"; - } - - @Override - public boolean isModifiedSQLExecuted() { - return modifiedSQLExecuted; - } - - @Override - public int getTxIsolation() { - return txIsolation; - } - - - -} +/* + * Copyright (c) 2013, OpenCloudDB/MyCAT and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software;Designed and Developed mainly by many Chinese + * opensource volunteers. you can redistribute it and/or modify it under the + * terms of the GNU General Public License version 2 only, as published by the + * Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Any questions about this component can be directed to it's project Web address + * https://code.google.com/p/opencloudb/. + * + */ +package io.mycat.backend.mysql.nio; + +import io.mycat.backend.mysql.xa.TxState; +import org.slf4j.Logger; import org.slf4j.LoggerFactory; + +import io.mycat.MycatServer; +import io.mycat.backend.mysql.CharsetUtil; +import io.mycat.backend.mysql.SecurityUtil; +import io.mycat.backend.mysql.nio.handler.ResponseHandler; +import io.mycat.config.Capabilities; +import io.mycat.config.Isolations; +import io.mycat.net.BackendAIOConnection; +import io.mycat.net.mysql.*; +import io.mycat.route.RouteResultsetNode; +import io.mycat.server.ServerConnection; +import io.mycat.server.parser.ServerParse; +import io.mycat.util.TimeUtil; +import io.mycat.util.exception.UnknownTxIsolationException; + +import java.io.UnsupportedEncodingException; +import java.nio.channels.NetworkChannel; +import java.security.NoSuchAlgorithmException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * @author mycat + */ +public class MySQLConnection extends BackendAIOConnection { + private static final Logger LOGGER = LoggerFactory + .getLogger(MySQLConnection.class); + private static final long CLIENT_FLAGS = initClientFlags(); + private volatile long lastTime; + private volatile String schema = null; + private volatile String oldSchema; + private volatile boolean borrowed = false; + private volatile boolean modifiedSQLExecuted = false; + private volatile int batchCmdCount = 0; + + private static long initClientFlags() { + int flag = 0; + flag |= Capabilities.CLIENT_LONG_PASSWORD; + flag |= Capabilities.CLIENT_FOUND_ROWS; + flag |= Capabilities.CLIENT_LONG_FLAG; + flag |= Capabilities.CLIENT_CONNECT_WITH_DB; + // flag |= Capabilities.CLIENT_NO_SCHEMA; + boolean usingCompress=MycatServer.getInstance().getConfig().getSystem().getUseCompression()==1 ; + if(usingCompress) + { + flag |= Capabilities.CLIENT_COMPRESS; + } + flag |= Capabilities.CLIENT_ODBC; + flag |= Capabilities.CLIENT_LOCAL_FILES; + flag |= Capabilities.CLIENT_IGNORE_SPACE; + flag |= Capabilities.CLIENT_PROTOCOL_41; + flag |= Capabilities.CLIENT_INTERACTIVE; + // flag |= Capabilities.CLIENT_SSL; + flag |= Capabilities.CLIENT_IGNORE_SIGPIPE; + flag |= Capabilities.CLIENT_TRANSACTIONS; + // flag |= Capabilities.CLIENT_RESERVED; + flag |= Capabilities.CLIENT_SECURE_CONNECTION; + // client extension + flag |= Capabilities.CLIENT_MULTI_STATEMENTS; + flag |= Capabilities.CLIENT_MULTI_RESULTS; + return flag; + } + + private static final CommandPacket _READ_UNCOMMITTED = new CommandPacket(); + private static final CommandPacket _READ_COMMITTED = new CommandPacket(); + private static final CommandPacket _REPEATED_READ = new CommandPacket(); + private static final CommandPacket _SERIALIZABLE = new CommandPacket(); + private static final CommandPacket _AUTOCOMMIT_ON = new CommandPacket(); + private static final CommandPacket _AUTOCOMMIT_OFF = new CommandPacket(); + private static final CommandPacket _COMMIT = new CommandPacket(); + private static final CommandPacket _ROLLBACK = new CommandPacket(); + static { + _READ_UNCOMMITTED.packetId = 0; + _READ_UNCOMMITTED.command = MySQLPacket.COM_QUERY; + _READ_UNCOMMITTED.arg = "SET SESSION TRANSACTION ISOLATION LEVEL READ UNCOMMITTED" + .getBytes(); + _READ_COMMITTED.packetId = 0; + _READ_COMMITTED.command = MySQLPacket.COM_QUERY; + _READ_COMMITTED.arg = "SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED" + .getBytes(); + _REPEATED_READ.packetId = 0; + _REPEATED_READ.command = MySQLPacket.COM_QUERY; + _REPEATED_READ.arg = "SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ" + .getBytes(); + _SERIALIZABLE.packetId = 0; + _SERIALIZABLE.command = MySQLPacket.COM_QUERY; + _SERIALIZABLE.arg = "SET SESSION TRANSACTION ISOLATION LEVEL SERIALIZABLE" + .getBytes(); + _AUTOCOMMIT_ON.packetId = 0; + _AUTOCOMMIT_ON.command = MySQLPacket.COM_QUERY; + _AUTOCOMMIT_ON.arg = "SET autocommit=1".getBytes(); + _AUTOCOMMIT_OFF.packetId = 0; + _AUTOCOMMIT_OFF.command = MySQLPacket.COM_QUERY; + _AUTOCOMMIT_OFF.arg = "SET autocommit=0".getBytes(); + _COMMIT.packetId = 0; + _COMMIT.command = MySQLPacket.COM_QUERY; + _COMMIT.arg = "commit".getBytes(); + _ROLLBACK.packetId = 0; + _ROLLBACK.command = MySQLPacket.COM_QUERY; + _ROLLBACK.arg = "rollback".getBytes(); + } + + private MySQLDataSource pool; + private boolean fromSlaveDB; + private long threadId; + private HandshakePacket handshake; + private volatile int txIsolation; + private volatile boolean autocommit; + private long clientFlags; + private boolean isAuthenticated; + private String user; + private String password; + private Object attachment; + private volatile ResponseHandler respHandler; + + private final AtomicBoolean isQuit; + private volatile StatusSync statusSync; + private volatile boolean metaDataSyned = true; + private volatile int xaStatus = 0; + + public MySQLConnection(NetworkChannel channel, boolean fromSlaveDB) { + super(channel); + this.clientFlags = CLIENT_FLAGS; + this.lastTime = TimeUtil.currentTimeMillis(); + this.isQuit = new AtomicBoolean(false); + this.autocommit = true; + this.fromSlaveDB = fromSlaveDB; + // 设为默认值,免得每个初始化好的连接都要去同步一下 + this.txIsolation = MycatServer.getInstance().getConfig().getSystem().getTxIsolation(); + } + + public int getXaStatus() { + return xaStatus; + } + + public void setXaStatus(int xaStatus) { + this.xaStatus = xaStatus; + } + + public void onConnectFailed(Throwable t) { + if (handler instanceof MySQLConnectionHandler) { + MySQLConnectionHandler theHandler = (MySQLConnectionHandler) handler; + theHandler.connectionError(t); + } else { + ((MySQLConnectionAuthenticator) handler).connectionError(this, t); + } + } + + public String getSchema() { + return this.schema; + } + + public void setSchema(String newSchema) { + String curSchema = schema; + if (curSchema == null) { + this.schema = newSchema; + this.oldSchema = newSchema; + } else { + this.oldSchema = curSchema; + this.schema = newSchema; + } + } + + public MySQLDataSource getPool() { + return pool; + } + + public void setPool(MySQLDataSource pool) { + this.pool = pool; + } + + public String getUser() { + return user; + } + + public void setUser(String user) { + this.user = user; + } + + public void setPassword(String password) { + this.password = password; + } + + public HandshakePacket getHandshake() { + return handshake; + } + + public void setHandshake(HandshakePacket handshake) { + this.handshake = handshake; + } + + public long getThreadId() { + return threadId; + } + + public void setThreadId(long threadId) { + this.threadId = threadId; + } + + public boolean isAuthenticated() { + return isAuthenticated; + } + + public void setAuthenticated(boolean isAuthenticated) { + this.isAuthenticated = isAuthenticated; + } + + public String getPassword() { + return password; + } + + public void authenticate() { + AuthPacket packet = new AuthPacket(); + packet.packetId = 1; + packet.clientFlags = clientFlags; + packet.maxPacketSize = maxPacketSize; + packet.charsetIndex = this.charsetIndex; + packet.user = user; + try { + packet.password = passwd(password, handshake); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException(e.getMessage()); + } + packet.database = schema; + packet.write(this); + } + + public boolean isAutocommit() { + return autocommit; + } + + public Object getAttachment() { + return attachment; + } + + public void setAttachment(Object attachment) { + this.attachment = attachment; + } + + public boolean isClosedOrQuit() { + return isClosed() || isQuit.get(); + } + + protected void sendQueryCmd(String query) { + CommandPacket packet = new CommandPacket(); + packet.packetId = 0; + packet.command = MySQLPacket.COM_QUERY; + try { + packet.arg = query.getBytes(charset); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + lastTime = TimeUtil.currentTimeMillis(); + packet.write(this); + } + + private static void getCharsetCommand(StringBuilder sb, int clientCharIndex) { + sb.append("SET names ").append(CharsetUtil.getCharset(clientCharIndex)) + .append(";"); + } + + private static void getTxIsolationCommand(StringBuilder sb, int txIsolation) { + switch (txIsolation) { + case Isolations.READ_UNCOMMITTED: + sb.append("SET SESSION TRANSACTION ISOLATION LEVEL READ UNCOMMITTED;"); + return; + case Isolations.READ_COMMITTED: + sb.append("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED;"); + return; + case Isolations.REPEATED_READ: + sb.append("SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ;"); + return; + case Isolations.SERIALIZABLE: + sb.append("SET SESSION TRANSACTION ISOLATION LEVEL SERIALIZABLE;"); + return; + default: + throw new UnknownTxIsolationException("txIsolation:" + txIsolation); + } + } + + private void getAutocommitCommand(StringBuilder sb, boolean autoCommit) { + if (autoCommit) { + sb.append("SET autocommit=1;"); + } else { + sb.append("SET autocommit=0;"); + } + } + + private static class StatusSync { + private final String schema; + private final Integer charsetIndex; + private final Integer txtIsolation; + private final Boolean autocommit; + private final AtomicInteger synCmdCount; + private final boolean xaStarted; + + public StatusSync(boolean xaStarted, String schema, + Integer charsetIndex, Integer txtIsolation, Boolean autocommit, + int synCount) { + super(); + this.xaStarted = xaStarted; + this.schema = schema; + this.charsetIndex = charsetIndex; + this.txtIsolation = txtIsolation; + this.autocommit = autocommit; + this.synCmdCount = new AtomicInteger(synCount); + } + + public boolean synAndExecuted(MySQLConnection conn) { + int remains = synCmdCount.decrementAndGet(); + if (remains == 0) {// syn command finished + this.updateConnectionInfo(conn); + conn.metaDataSyned = true; + return false; + } else if (remains < 0) { + return true; + } + return false; + } + + private void updateConnectionInfo(MySQLConnection conn) + + { + if (schema != null) { + conn.schema = schema; + conn.oldSchema = conn.schema; + } + if (charsetIndex != null) { + conn.setCharset(CharsetUtil.getCharset(charsetIndex)); + } + if (txtIsolation != null) { + conn.txIsolation = txtIsolation; + } + if (autocommit != null) { + conn.autocommit = autocommit; + } + } + + } + + /** + * @return if synchronization finished and execute-sql has already been sent + * before + */ + public boolean syncAndExcute() { + StatusSync sync = this.statusSync; + if (sync == null) { + return true; + } else { + boolean executed = sync.synAndExecuted(this); + if (executed) { + statusSync = null; + } + return executed; + } + + } + + public void execute(RouteResultsetNode rrn, ServerConnection sc, + boolean autocommit) throws UnsupportedEncodingException { + if (!modifiedSQLExecuted && rrn.isModifySQL()) { + modifiedSQLExecuted = true; + } + String xaTXID = null; + if(sc.getSession2().getXaTXID()!=null){ + xaTXID = sc.getSession2().getXaTXID()+",'"+getSchema()+"'"; + } + synAndDoExecute(xaTXID, rrn, sc.getCharsetIndex(), sc.getTxIsolation(), + autocommit); + } + + private void synAndDoExecute(String xaTxID, RouteResultsetNode rrn, + int clientCharSetIndex, int clientTxIsoLation, + boolean clientAutoCommit) { + String xaCmd = null; + + boolean conAutoComit = this.autocommit; + String conSchema = this.schema; + boolean strictTxIsolation = MycatServer.getInstance().getConfig().getSystem().isStrictTxIsolation(); + boolean expectAutocommit = false; + // 如果在非自动提交情况下,如果需要严格保证事务级别,则需做下列判断 + if (strictTxIsolation) { + expectAutocommit = isFromSlaveDB() || clientAutoCommit; + } else { + // never executed modify sql,so auto commit + expectAutocommit = (!modifiedSQLExecuted || isFromSlaveDB() || clientAutoCommit); + } + if (expectAutocommit == false && xaTxID != null && xaStatus == TxState.TX_INITIALIZE_STATE) { + //clientTxIsoLation = Isolations.SERIALIZABLE; + xaCmd = "XA START " + xaTxID + ';'; + this.xaStatus = TxState.TX_STARTED_STATE; + } + int schemaSyn = conSchema.equals(oldSchema) ? 0 : 1; + int charsetSyn = 0; + if (this.charsetIndex != clientCharSetIndex) { + //need to syn the charset of connection. + //set current connection charset to client charset. + //otherwise while sending commend to server the charset will not coincidence. + setCharset(CharsetUtil.getCharset(clientCharSetIndex)); + charsetSyn = 1; + } + int txIsoLationSyn = (txIsolation == clientTxIsoLation) ? 0 : 1; + int autoCommitSyn = (conAutoComit == expectAutocommit) ? 0 : 1; + int synCount = schemaSyn + charsetSyn + txIsoLationSyn + autoCommitSyn + (xaCmd!=null?1:0); +// if (synCount == 0 && this.xaStatus != TxState.TX_STARTED_STATE) { + if (synCount == 0 ) { + // not need syn connection + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("not need syn connection :\n" + this+"\n to send query cmd:\n"+rrn.getStatement() + +"\n in pool\n" + +this.getPool().getConfig()); + } + sendQueryCmd(rrn.getStatement()); + return; + } + CommandPacket schemaCmd = null; + StringBuilder sb = new StringBuilder(); + if (schemaSyn == 1) { + schemaCmd = getChangeSchemaCommand(conSchema); + // getChangeSchemaCommand(sb, conSchema); + } + + if (charsetSyn == 1) { + getCharsetCommand(sb, clientCharSetIndex); + } + if (txIsoLationSyn == 1) { + getTxIsolationCommand(sb, clientTxIsoLation); + } + if (autoCommitSyn == 1) { + getAutocommitCommand(sb, expectAutocommit); + } + if (xaCmd != null) { + sb.append(xaCmd); + } + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("con need syn ,total syn cmd " + synCount + + " commands " + sb.toString() + "schema change:" + + (schemaCmd != null) + " con:" + this); + } + metaDataSyned = false; + statusSync = new StatusSync(xaCmd != null, conSchema, + clientCharSetIndex, clientTxIsoLation, expectAutocommit, + synCount); + // syn schema + if (schemaCmd != null) { + schemaCmd.write(this); + } + // and our query sql to multi command at last + sb.append(rrn.getStatement()+";"); + // syn and execute others + this.sendQueryCmd(sb.toString()); + // waiting syn result... + + } + + private static CommandPacket getChangeSchemaCommand(String schema) { + CommandPacket cmd = new CommandPacket(); + cmd.packetId = 0; + cmd.command = MySQLPacket.COM_INIT_DB; + cmd.arg = schema.getBytes(); + return cmd; + } + + /** + * by wuzh ,execute a query and ignore transaction settings for performance + * + * @param query + * @throws UnsupportedEncodingException + */ + public void query(String query) throws UnsupportedEncodingException { + RouteResultsetNode rrn = new RouteResultsetNode("default", + ServerParse.SELECT, query); + + synAndDoExecute(null, rrn, this.charsetIndex, this.txIsolation, true); + + } + /** + * by zwy ,execute a query with charsetIndex + * + * @param query + * @throws UnsupportedEncodingException + */ + @Override + public void query(String query, int charsetIndex) { + RouteResultsetNode rrn = new RouteResultsetNode("default", + ServerParse.SELECT, query); + + synAndDoExecute(null, rrn, charsetIndex, this.txIsolation, true); + + } + public long getLastTime() { + return lastTime; + } + + public void setLastTime(long lastTime) { + this.lastTime = lastTime; + } + + public void quit() { + if (isQuit.compareAndSet(false, true) && !isClosed()) { + if (isAuthenticated) { + write(writeToBuffer(QuitPacket.QUIT, allocate())); + write(allocate()); + } else { + close("normal"); + } + } + } + + @Override + public void close(String reason) { + if (!isClosed.get()) { + isQuit.set(true); + ResponseHandler tmpRespHandlers= respHandler; + setResponseHandler(null); + super.close(reason); + pool.connectionClosed(this); + if (tmpRespHandlers != null) { + tmpRespHandlers.connectionClose(this, reason); + } + if( this.handler instanceof MySQLConnectionAuthenticator) { + ((MySQLConnectionAuthenticator) this.handler).connectionError(this, new Throwable(reason)); + + } + } else { + //主要起一个清理资源的作用 + super.close(reason); + } + } + + public void commit() { + + _COMMIT.write(this); + + } + + public boolean batchCmdFinished() { + batchCmdCount--; + return (batchCmdCount == 0); + } + + public void execCmd(String cmd) { + this.sendQueryCmd(cmd); + } + + public void execBatchCmd(String[] batchCmds) { + // "XA END "+xaID+";"+"XA PREPARE "+xaID + this.batchCmdCount = batchCmds.length; + StringBuilder sb = new StringBuilder(); + for (String sql : batchCmds) { + sb.append(sql).append(';'); + } + this.sendQueryCmd(sb.toString()); + } + + public void rollback() { + _ROLLBACK.write(this); + } + + public void release() { + if (metaDataSyned == false) {// indicate connection not normalfinished + // ,and + // we can't know it's syn status ,so + // close + // it + LOGGER.warn("can't sure connection syn result,so close it " + this); + this.respHandler = null; + this.close("syn status unkown "); + return; + } + metaDataSyned = true; + attachment = null; + statusSync = null; + modifiedSQLExecuted = false; + xaStatus = TxState.TX_INITIALIZE_STATE; + setResponseHandler(null); + pool.releaseChannel(this); + } + + public boolean setResponseHandler(ResponseHandler queryHandler) { + if (handler instanceof MySQLConnectionHandler) { + ((MySQLConnectionHandler) handler).setResponseHandler(queryHandler); + respHandler = queryHandler; + return true; + } else if (queryHandler != null) { + LOGGER.warn("set not MySQLConnectionHandler " + + queryHandler.getClass().getCanonicalName()); + } + return false; + } + + /** + * 写队列为空,可以继续写数据 + */ + public void writeQueueAvailable() { + if (respHandler != null) { + respHandler.writeQueueAvailable(); + } + } + + /** + * 记录sql执行信息 + */ + public void recordSql(String host, String schema, String stmt) { + // final long now = TimeUtil.currentTimeMillis(); + // if (now > this.lastTime) { + // // long time = now - this.lastTime; + // // SQLRecorder sqlRecorder = this.pool.getSqlRecorder(); + // // if (sqlRecorder.check(time)) { + // // SQLRecord recorder = new SQLRecord(); + // // recorder.host = host; + // // recorder.schema = schema; + // // recorder.statement = stmt; + // // recorder.startTime = lastTime; + // // recorder.executeTime = time; + // // recorder.dataNode = pool.getName(); + // // recorder.dataNodeIndex = pool.getIndex(); + // // sqlRecorder.add(recorder); + // // } + // } + // this.lastTime = now; + } + + private static byte[] passwd(String pass, HandshakePacket hs) + throws NoSuchAlgorithmException { + if (pass == null || pass.length() == 0) { + return null; + } + byte[] passwd = pass.getBytes(); + int sl1 = hs.seed.length; + int sl2 = hs.restOfScrambleBuff.length; + byte[] seed = new byte[sl1 + sl2]; + System.arraycopy(hs.seed, 0, seed, 0, sl1); + System.arraycopy(hs.restOfScrambleBuff, 0, seed, sl1, sl2); + return SecurityUtil.scramble411(passwd, seed); + } + + @Override + public boolean isFromSlaveDB() { + return fromSlaveDB; + } + + @Override + public boolean isBorrowed() { + return borrowed; + } + + @Override + public void setBorrowed(boolean borrowed) { + this.lastTime = TimeUtil.currentTimeMillis(); + this.borrowed = borrowed; + } + + @Override + public String toString() { + return "MySQLConnection [id=" + id + ", lastTime=" + lastTime + + ", user=" + user + + ", schema=" + schema + ", old shema=" + oldSchema + + ", borrowed=" + borrowed + ", fromSlaveDB=" + fromSlaveDB + + ", threadId=" + threadId + ", charset=" + charset + + ", txIsolation=" + txIsolation + ", autocommit=" + autocommit + + ", attachment=" + attachment + ", respHandler=" + respHandler + + ", host=" + host + ", port=" + port + ", statusSync=" + + statusSync + ", writeQueue=" + this.getWriteQueue().size() + + ", modifiedSQLExecuted=" + modifiedSQLExecuted + "]"; + } + + @Override + public boolean isModifiedSQLExecuted() { + return modifiedSQLExecuted; + } + + @Override + public int getTxIsolation() { + return txIsolation; + } + + + +} diff --git a/src/main/java/io/mycat/route/RouteResultset.java b/src/main/java/io/mycat/route/RouteResultset.java index 2e15fdec3..55b23cc5f 100644 --- a/src/main/java/io/mycat/route/RouteResultset.java +++ b/src/main/java/io/mycat/route/RouteResultset.java @@ -77,8 +77,6 @@ public final class RouteResultset implements Serializable { private boolean selectForUpdate; private boolean autoIncrement; - - private Map> subTableMaps; public boolean isSelectForUpdate() { return selectForUpdate; @@ -449,15 +447,4 @@ public void setAutoIncrement(boolean b) { public boolean getAutoIncrement() { return autoIncrement; } - - public Map> getSubTableMaps() { - return subTableMaps; - } - - public void setSubTableMaps(Map> subTableMaps) { - this.subTableMaps = subTableMaps; - } - - - } diff --git a/src/main/java/io/mycat/route/RouteResultsetNode.java b/src/main/java/io/mycat/route/RouteResultsetNode.java index 35a6292e8..f574506dc 100644 --- a/src/main/java/io/mycat/route/RouteResultsetNode.java +++ b/src/main/java/io/mycat/route/RouteResultsetNode.java @@ -1,322 +1,307 @@ -/* - * Copyright (c) 2013, OpenCloudDB/MyCAT and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software;Designed and Developed mainly by many Chinese - * opensource volunteers. you can redistribute it and/or modify it under the - * terms of the GNU General Public License version 2 only, as published by the - * Free Software Foundation. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Any questions about this component can be directed to it's project Web address - * https://code.google.com/p/opencloudb/. - * - */ -package io.mycat.route; - -import java.io.Serializable; -import java.util.Map; - -import org.apache.logging.log4j.util.Strings; - -import io.mycat.server.parser.ServerParse; -import io.mycat.sqlengine.mpp.LoadData; - -/** - * @author mycat - */ -public final class RouteResultsetNode implements Serializable , Comparable { - /** - * - */ - private static final long serialVersionUID = 1L; - private final String name; // 数据节点名称 - private String statement; // 执行的语句 - private final String srcStatement; - private final int sqlType; - private volatile boolean canRunInReadDB; - private final boolean hasBlanceFlag; - private boolean callStatement = false; // 处理call关键字 - private int limitStart; - private int limitSize; - private int totalNodeSize =0; //方便后续jdbc批量获取扩展 - private Procedure procedure; - private LoadData loadData; - private RouteResultset source; - - // 强制走 master,可以通过 RouteResultset的属性canRunInReadDB(false) - // 传给 RouteResultsetNode 来实现,但是 强制走 slave需要增加一个属性来实现: - private Boolean runOnSlave = null; // 默认null表示不施加影响, true走slave,false走master - - private String subTableName; // 分表的表名 - private Map subTableNames;//分表的表名集合 - - //迁移算法用 -2代表不是slot分片 ,-1代表扫描所有分片 - private int slot=-2; - - public RouteResultsetNode(String name, int sqlType, String srcStatement) { - this.name = name; - limitStart=0; - this.limitSize = -1; - this.sqlType = sqlType; - this.srcStatement = srcStatement; - this.statement = srcStatement; - canRunInReadDB = (sqlType == ServerParse.SELECT || sqlType == ServerParse.SHOW); - hasBlanceFlag = (statement != null) - && statement.startsWith("/*balance*/"); - } - - public Map getSubTableNames() { - return subTableNames; - } - - public void setSubTableNames(Map subTableNames) { - this.subTableNames = subTableNames; - } - - public Boolean getRunOnSlave() { - return runOnSlave; - } - public String getRunOnSlaveDebugInfo() { - return runOnSlave == null?" default ":Boolean.toString(runOnSlave); - } - public boolean isUpdateSql() { - int type=sqlType; - return ServerParse.INSERT==type||ServerParse.UPDATE==type||ServerParse.DELETE==type||ServerParse.DDL==type; - } - public void setRunOnSlave(Boolean runOnSlave) { - this.runOnSlave = runOnSlave; - } - private Map hintMap; - - public Map getHintMap() - { - return hintMap; - } - - public void setHintMap(Map hintMap) - { - this.hintMap = hintMap; - } - - public void setStatement(String statement) { - this.statement = statement; - } - - public void setCanRunInReadDB(boolean canRunInReadDB) { - this.canRunInReadDB = canRunInReadDB; - } - - public boolean getCanRunInReadDB() { - return this.canRunInReadDB; - } - - public void resetStatement() { - this.statement = srcStatement; - } - - /** - * 这里的逻辑是为了优化,实现:非业务sql可以在负载均衡走slave的效果。因为业务sql一般是非自动提交, - * 而非业务sql一般默认是自动提交,比如mysql client,还有SQLJob, heartbeat都可以使用 - * 了Leader-us优化的query函数,该函数实现为自动提交; - * - * 在非自动提交的情况下(有事物),除非使用了 balance 注解的情况下,才可以走slave. - * - * 当然还有一个大前提,必须是 select 或者 show 语句(canRunInReadDB=true) - * @param autocommit - * @return - */ - public boolean canRunnINReadDB(boolean autocommit) { - return canRunInReadDB && ( autocommit || (!autocommit && hasBlanceFlag) ); - } - -// public boolean canRunnINReadDB(boolean autocommit) { -// return canRunInReadDB && autocommit && !hasBlanceFlag -// || canRunInReadDB && !autocommit && hasBlanceFlag; -// } - public Procedure getProcedure() - { - return procedure; - } - - public int getSlot() { - return slot; - } - - public void setSlot(int slot) { - this.slot = slot; - } - - public void setProcedure(Procedure procedure) - { - this.procedure = procedure; - } - - public boolean isCallStatement() - { - return callStatement; - } - - public void setCallStatement(boolean callStatement) - { - this.callStatement = callStatement; - } - public String getName() { - return name; - } - - public int getSqlType() { - return sqlType; - } - - public String getStatement() { - return statement; - } - - public int getLimitStart() - { - return limitStart; - } - - public void setLimitStart(int limitStart) - { - this.limitStart = limitStart; - } - - public int getLimitSize() - { - return limitSize; - } - - public void setLimitSize(int limitSize) - { - this.limitSize = limitSize; - } - - public int getTotalNodeSize() - { - return totalNodeSize; - } - - public void setTotalNodeSize(int totalNodeSize) - { - this.totalNodeSize = totalNodeSize; - } - - public LoadData getLoadData() - { - return loadData; - } - - public void setLoadData(LoadData loadData) - { - this.loadData = loadData; - } - - @Override - public int hashCode() { - return name.hashCode(); - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj instanceof RouteResultsetNode) { - RouteResultsetNode rrn = (RouteResultsetNode) obj; - if(subTableName!=null){ - if (equals(name, rrn.getName()) && equals(subTableName, rrn.getSubTableName())) { - return true; - } - }else{ - if (equals(name, rrn.getName())) { - return true; - } - } - } - return false; - } - - @Override - public String toString() { - StringBuilder s = new StringBuilder(); - s.append(name); - s.append('{').append(statement).append('}'); - return s.toString(); - } - - private static boolean equals(String str1, String str2) { - if (str1 == null) { - return str2 == null; - } - return str1.equals(str2); - } - - public String getSubTableName() { - return this.subTableName; - } - - public void setSubTableName(String subTableName) { - this.subTableName = subTableName; - } - - public boolean isModifySQL() { - return !canRunInReadDB; - } - public boolean isDisctTable() { - if(subTableName!=null && !subTableName.equals("")){ - return true; - }; - return false; - } - - - @Override - public int compareTo(RouteResultsetNode obj) { - if(obj == null) { - return 1; - } - if(this.name == null) { - return -1; - } - if(obj.name == null) { - return 1; - } - int c = this.name.compareTo(obj.name); - if(!this.isDisctTable()){ - return c; - }else{ - if(c==0){ - String subTableName = obj.subTableName; - - if (Strings.isBlank(subTableName)) { - return 1; - } - return this.subTableName.compareTo(subTableName); - } - return c; - } - } - - public boolean isHasBlanceFlag() { - return hasBlanceFlag; - } - - public RouteResultset getSource() { - return source; - } - - public void setSource(RouteResultset source) { - this.source = source; - } -} +/* + * Copyright (c) 2013, OpenCloudDB/MyCAT and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software;Designed and Developed mainly by many Chinese + * opensource volunteers. you can redistribute it and/or modify it under the + * terms of the GNU General Public License version 2 only, as published by the + * Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Any questions about this component can be directed to it's project Web address + * https://code.google.com/p/opencloudb/. + * + */ +package io.mycat.route; + +import java.io.Serializable; +import java.util.Map; +import java.util.Set; + +import io.mycat.server.parser.ServerParse; +import io.mycat.sqlengine.mpp.LoadData; + +/** + * @author mycat + */ +public final class RouteResultsetNode implements Serializable , Comparable { + /** + * + */ + private static final long serialVersionUID = 1L; + private final String name; // 数据节点名称 + private String statement; // 执行的语句 + private final String srcStatement; + private final int sqlType; + private volatile boolean canRunInReadDB; + private final boolean hasBlanceFlag; + private boolean callStatement = false; // 处理call关键字 + private int limitStart; + private int limitSize; + private int totalNodeSize =0; //方便后续jdbc批量获取扩展 + private Procedure procedure; + private LoadData loadData; + private RouteResultset source; + + // 强制走 master,可以通过 RouteResultset的属性canRunInReadDB(false) + // 传给 RouteResultsetNode 来实现,但是 强制走 slave需要增加一个属性来实现: + private Boolean runOnSlave = null; // 默认null表示不施加影响, true走slave,false走master + + private String subTableName; // 分表的表名 + + //迁移算法用 -2代表不是slot分片 ,-1代表扫描所有分片 + private int slot=-2; + + public RouteResultsetNode(String name, int sqlType, String srcStatement) { + this.name = name; + limitStart=0; + this.limitSize = -1; + this.sqlType = sqlType; + this.srcStatement = srcStatement; + this.statement = srcStatement; + canRunInReadDB = (sqlType == ServerParse.SELECT || sqlType == ServerParse.SHOW); + hasBlanceFlag = (statement != null) + && statement.startsWith("/*balance*/"); + } + + public Boolean getRunOnSlave() { + return runOnSlave; + } + public String getRunOnSlaveDebugInfo() { + return runOnSlave == null?" default ":Boolean.toString(runOnSlave); + } + public boolean isUpdateSql() { + int type=sqlType; + return ServerParse.INSERT==type||ServerParse.UPDATE==type||ServerParse.DELETE==type||ServerParse.DDL==type; + } + public void setRunOnSlave(Boolean runOnSlave) { + this.runOnSlave = runOnSlave; + } + private Map hintMap; + + public Map getHintMap() + { + return hintMap; + } + + public void setHintMap(Map hintMap) + { + this.hintMap = hintMap; + } + + public void setStatement(String statement) { + this.statement = statement; + } + + public void setCanRunInReadDB(boolean canRunInReadDB) { + this.canRunInReadDB = canRunInReadDB; + } + + public boolean getCanRunInReadDB() { + return this.canRunInReadDB; + } + + public void resetStatement() { + this.statement = srcStatement; + } + + /** + * 这里的逻辑是为了优化,实现:非业务sql可以在负载均衡走slave的效果。因为业务sql一般是非自动提交, + * 而非业务sql一般默认是自动提交,比如mysql client,还有SQLJob, heartbeat都可以使用 + * 了Leader-us优化的query函数,该函数实现为自动提交; + * + * 在非自动提交的情况下(有事物),除非使用了 balance 注解的情况下,才可以走slave. + * + * 当然还有一个大前提,必须是 select 或者 show 语句(canRunInReadDB=true) + * @param autocommit + * @return + */ + public boolean canRunnINReadDB(boolean autocommit) { + return canRunInReadDB && ( autocommit || (!autocommit && hasBlanceFlag) ); + } + +// public boolean canRunnINReadDB(boolean autocommit) { +// return canRunInReadDB && autocommit && !hasBlanceFlag +// || canRunInReadDB && !autocommit && hasBlanceFlag; +// } + public Procedure getProcedure() + { + return procedure; + } + + public int getSlot() { + return slot; + } + + public void setSlot(int slot) { + this.slot = slot; + } + + public void setProcedure(Procedure procedure) + { + this.procedure = procedure; + } + + public boolean isCallStatement() + { + return callStatement; + } + + public void setCallStatement(boolean callStatement) + { + this.callStatement = callStatement; + } + public String getName() { + return name; + } + + public int getSqlType() { + return sqlType; + } + + public String getStatement() { + return statement; + } + + public int getLimitStart() + { + return limitStart; + } + + public void setLimitStart(int limitStart) + { + this.limitStart = limitStart; + } + + public int getLimitSize() + { + return limitSize; + } + + public void setLimitSize(int limitSize) + { + this.limitSize = limitSize; + } + + public int getTotalNodeSize() + { + return totalNodeSize; + } + + public void setTotalNodeSize(int totalNodeSize) + { + this.totalNodeSize = totalNodeSize; + } + + public LoadData getLoadData() + { + return loadData; + } + + public void setLoadData(LoadData loadData) + { + this.loadData = loadData; + } + + @Override + public int hashCode() { + return name.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj instanceof RouteResultsetNode) { + RouteResultsetNode rrn = (RouteResultsetNode) obj; + if(subTableName!=null){ + if (equals(name, rrn.getName()) && equals(subTableName, rrn.getSubTableName())) { + return true; + } + }else{ + if (equals(name, rrn.getName())) { + return true; + } + } + } + return false; + } + + @Override + public String toString() { + StringBuilder s = new StringBuilder(); + s.append(name); + s.append('{').append(statement).append('}'); + return s.toString(); + } + + private static boolean equals(String str1, String str2) { + if (str1 == null) { + return str2 == null; + } + return str1.equals(str2); + } + + public String getSubTableName() { + return this.subTableName; + } + + public void setSubTableName(String subTableName) { + this.subTableName = subTableName; + } + + public boolean isModifySQL() { + return !canRunInReadDB; + } + public boolean isDisctTable() { + if(subTableName!=null && !subTableName.equals("")){ + return true; + }; + return false; + } + + + @Override + public int compareTo(RouteResultsetNode obj) { + if(obj == null) { + return 1; + } + if(this.name == null) { + return -1; + } + if(obj.name == null) { + return 1; + } + int c = this.name.compareTo(obj.name); + if(!this.isDisctTable()){ + return c; + }else{ + if(c==0){ + return this.subTableName.compareTo(obj.subTableName); + } + return c; + } + } + + public boolean isHasBlanceFlag() { + return hasBlanceFlag; + } + + public RouteResultset getSource() { + return source; + } + + public void setSource(RouteResultset source) { + this.source = source; + } +} diff --git a/src/main/java/io/mycat/route/impl/DruidMycatRouteStrategy.java b/src/main/java/io/mycat/route/impl/DruidMycatRouteStrategy.java index 401c48535..ba33f1e0c 100644 --- a/src/main/java/io/mycat/route/impl/DruidMycatRouteStrategy.java +++ b/src/main/java/io/mycat/route/impl/DruidMycatRouteStrategy.java @@ -1,774 +1,756 @@ -package io.mycat.route.impl; - -import java.sql.SQLNonTransientException; -import java.sql.SQLSyntaxErrorException; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.SortedSet; -import java.util.TreeSet; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.alibaba.druid.sql.SQLUtils; -import com.alibaba.druid.sql.ast.SQLExpr; -import com.alibaba.druid.sql.ast.SQLObject; -import com.alibaba.druid.sql.ast.SQLStatement; -import com.alibaba.druid.sql.ast.expr.SQLAllExpr; -import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr; -import com.alibaba.druid.sql.ast.expr.SQLExistsExpr; -import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr; -import com.alibaba.druid.sql.ast.expr.SQLInSubQueryExpr; -import com.alibaba.druid.sql.ast.expr.SQLQueryExpr; -import com.alibaba.druid.sql.ast.statement.SQLDeleteStatement; -import com.alibaba.druid.sql.ast.statement.SQLExprTableSource; -import com.alibaba.druid.sql.ast.statement.SQLInsertStatement; -import com.alibaba.druid.sql.ast.statement.SQLSelect; -import com.alibaba.druid.sql.ast.statement.SQLSelectQuery; -import com.alibaba.druid.sql.ast.statement.SQLSelectStatement; -import com.alibaba.druid.sql.ast.statement.SQLTableSource; -import com.alibaba.druid.sql.ast.statement.SQLUpdateStatement; -import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement; -import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlReplaceStatement; -import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock; -import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser; -import com.alibaba.druid.sql.parser.SQLStatementParser; -import com.alibaba.druid.stat.TableStat.Relationship; -import com.google.common.base.Strings; - -import io.mycat.MycatServer; -import io.mycat.backend.mysql.nio.handler.MiddlerQueryResultHandler; -import io.mycat.backend.mysql.nio.handler.MiddlerResultHandler; -import io.mycat.backend.mysql.nio.handler.SecondHandler; -import io.mycat.cache.LayerCachePool; -import io.mycat.config.ErrorCode; -import io.mycat.config.model.SchemaConfig; -import io.mycat.config.model.TableConfig; -import io.mycat.config.model.rule.RuleConfig; -import io.mycat.route.RouteResultset; -import io.mycat.route.RouteResultsetNode; -import io.mycat.route.function.SlotFunction; -import io.mycat.route.impl.middlerResultStrategy.BinaryOpResultHandler; -import io.mycat.route.impl.middlerResultStrategy.InSubQueryResultHandler; -import io.mycat.route.impl.middlerResultStrategy.RouteMiddlerReaultHandler; -import io.mycat.route.impl.middlerResultStrategy.SQLAllResultHandler; -import io.mycat.route.impl.middlerResultStrategy.SQLExistsResultHandler; -import io.mycat.route.impl.middlerResultStrategy.SQLQueryResultHandler; -import io.mycat.route.parser.druid.DruidParser; -import io.mycat.route.parser.druid.DruidParserFactory; -import io.mycat.route.parser.druid.DruidShardingParseInfo; -import io.mycat.route.parser.druid.MycatSchemaStatVisitor; -import io.mycat.route.parser.druid.MycatStatementParser; -import io.mycat.route.parser.druid.RouteCalculateUnit; -import io.mycat.route.parser.util.ParseUtil; -import io.mycat.route.util.RouterUtil; -import io.mycat.server.NonBlockingSession; -import io.mycat.server.ServerConnection; -import io.mycat.server.parser.ServerParse; - -public class DruidMycatRouteStrategy extends AbstractRouteStrategy { - - public static final Logger LOGGER = LoggerFactory.getLogger(DruidMycatRouteStrategy.class); - - private static Map,RouteMiddlerReaultHandler> middlerResultHandler = new HashMap<>(); - - static{ - middlerResultHandler.put(SQLQueryExpr.class, new SQLQueryResultHandler()); - middlerResultHandler.put(SQLBinaryOpExpr.class, new BinaryOpResultHandler()); - middlerResultHandler.put(SQLInSubQueryExpr.class, new InSubQueryResultHandler()); - middlerResultHandler.put(SQLExistsExpr.class, new SQLExistsResultHandler()); - middlerResultHandler.put(SQLAllExpr.class, new SQLAllResultHandler()); - } - - - @Override - public RouteResultset routeNormalSqlWithAST(SchemaConfig schema, - String stmt, RouteResultset rrs,String charset, - LayerCachePool cachePool,int sqlType,ServerConnection sc) throws SQLNonTransientException { - - /** - * 只有mysql时只支持mysql语法 - */ - SQLStatementParser parser = null; - if (schema.isNeedSupportMultiDBType()) { - parser = new MycatStatementParser(stmt); - } else { - parser = new MySqlStatementParser(stmt); - } - - MycatSchemaStatVisitor visitor = null; - SQLStatement statement; - - /** - * 解析出现问题统一抛SQL语法错误 - */ - try { - statement = parser.parseStatement(); - visitor = new MycatSchemaStatVisitor(); - } catch (Exception t) { - LOGGER.error("DruidMycatRouteStrategyError", t); - throw new SQLSyntaxErrorException(t); - } - - /** - * 检验unsupported statement - */ - checkUnSupportedStatement(statement); - - DruidParser druidParser = DruidParserFactory.create(schema, statement, visitor); - druidParser.parser(schema, rrs, statement, stmt,cachePool,visitor); - DruidShardingParseInfo ctx= druidParser.getCtx() ; - rrs.setTables(ctx.getTables()); - - if(visitor.isSubqueryRelationOr()){ - String err = "In subQuery,the or condition is not supported."; - LOGGER.error(err); - throw new SQLSyntaxErrorException(err); - } - - /* 按照以下情况路由 - 1.2.1 可以直接路由. - 1.2.2 两个表夸库join的sql.调用calat - 1.2.3 需要先执行subquery 的sql.把subquery拆分出来.获取结果后,与outerquery - */ - - //add huangyiming 分片规则不一样的且表中带查询条件的则走Catlet - List tables = ctx.getTables(); - SchemaConfig schemaConf = MycatServer.getInstance().getConfig().getSchemas().get(schema.getName()); - int index = 0; - RuleConfig firstRule = null; - boolean directRoute = true; - Set firstDataNodes = new HashSet(); - Map tconfigs = schemaConf==null?null:schemaConf.getTables(); - - Map rulemap = new HashMap<>(); - if(tconfigs!=null){ - for(String tableName : tables){ - TableConfig tc = tconfigs.get(tableName); - if(tc == null){ - //add 别名中取 - Map tableAliasMap = ctx.getTableAliasMap(); - if(tableAliasMap !=null && tableAliasMap.get(tableName) !=null){ - tc = schemaConf.getTables().get(tableAliasMap.get(tableName)); - } - } - - if(index == 0){ - if(tc !=null){ - firstRule= tc.getRule(); - //没有指定分片规则时,不做处理 - if(firstRule==null){ - continue; - } - firstDataNodes.addAll(tc.getDataNodes()); - rulemap.put(tc.getName(), firstRule); - } - }else{ - if(tc !=null){ - //ER关系表的时候是可能存在字表中没有tablerule的情况,所以加上判断 - RuleConfig ruleCfg = tc.getRule(); - if(ruleCfg==null){ //没有指定分片规则时,不做处理 - continue; - } - Set dataNodes = new HashSet(); - dataNodes.addAll(tc.getDataNodes()); - rulemap.put(tc.getName(), ruleCfg); - //如果匹配规则不相同或者分片的datanode不相同则需要走子查询处理 - if(firstRule!=null&&((ruleCfg !=null && !ruleCfg.getRuleAlgorithm().equals(firstRule.getRuleAlgorithm()) )||( !dataNodes.equals(firstDataNodes)))){ - directRoute = false; - break; - } - } - } - index++; - } - } - - RouteResultset rrsResult = rrs; - if(directRoute){ //直接路由 - if(!RouterUtil.isAllGlobalTable(ctx, schemaConf)){ - if(rulemap.size()>1&&!checkRuleField(rulemap,visitor)){ - String err = "In case of slice table,there is no rule field in the relationship condition!"; - LOGGER.error(err); - throw new SQLSyntaxErrorException(err); - } - } - rrsResult = directRoute(rrs,ctx,schema,druidParser,statement,cachePool); - }else{ - int subQuerySize = visitor.getSubQuerys().size(); - if(subQuerySize==0&&ctx.getTables().size()==2){ //两表关联,考虑使用catlet - if(!visitor.getRelationships().isEmpty()){ - rrs.setCacheAble(false); - rrs.setFinishedRoute(true); - rrsResult = catletRoute(schema,ctx.getSql(),charset,sc); - }else{ - rrsResult = directRoute(rrs,ctx,schema,druidParser,statement,cachePool); - } - }else if(subQuerySize==1){ //只涉及一张表的子查询,使用 MiddlerResultHandler 获取中间结果后,改写原有 sql 继续执行 TODO 后期可能会考虑多个子查询的情况. - SQLSelect sqlselect = visitor.getSubQuerys().iterator().next(); - if(!visitor.getRelationships().isEmpty()){ // 当 inner query 和 outer query 有关联条件时,暂不支持 - String err = "In case of slice table,sql have different rules,the relationship condition is not supported."; - LOGGER.error(err); - throw new SQLSyntaxErrorException(err); - }else{ - SQLSelectQuery sqlSelectQuery = sqlselect.getQuery(); - if(((MySqlSelectQueryBlock)sqlSelectQuery).getFrom() instanceof SQLExprTableSource) { - rrs.setCacheAble(false); - rrs.setFinishedRoute(true); - rrsResult = middlerResultRoute(schema,charset,sqlselect,sqlType,statement,sc); - } - } - }else if(subQuerySize >=2){ - String err = "In case of slice table,sql has different rules,currently only one subQuery is supported."; - LOGGER.error(err); - throw new SQLSyntaxErrorException(err); - } - } - return rrsResult; - } - - /** - * 子查询中存在关联查询的情况下,检查关联字段是否是分片字段 - * @param rulemap - * @param ships - * @return - */ - private boolean checkRuleField(Map rulemap,MycatSchemaStatVisitor visitor){ - - if(!MycatServer.getInstance().getConfig().getSystem().isSubqueryRelationshipCheck()){ - return true; - } - - Set ships = visitor.getRelationships(); - Iterator iter = ships.iterator(); - while(iter.hasNext()){ - Relationship ship = iter.next(); - String lefttable = ship.getLeft().getTable().toUpperCase(); - String righttable = ship.getRight().getTable().toUpperCase(); - // 如果是同一个表中的关联条件,不做处理 - if(lefttable.equals(righttable)){ - return true; - } - RuleConfig leftconfig = rulemap.get(lefttable); - RuleConfig rightconfig = rulemap.get(righttable); - - if(null!=leftconfig&&null!=rightconfig - &&leftconfig.equals(rightconfig) - &&leftconfig.getColumn().equals(ship.getLeft().getName().toUpperCase()) - &&rightconfig.getColumn().equals(ship.getRight().getName().toUpperCase())){ - return true; - } - } - return false; - } - - private RouteResultset middlerResultRoute(final SchemaConfig schema,final String charset,final SQLSelect sqlselect, - final int sqlType,final SQLStatement statement,final ServerConnection sc){ - - final String middlesql = SQLUtils.toMySqlString(sqlselect); - - MiddlerResultHandler middlerResultHandler = new MiddlerQueryResultHandler<>(new SecondHandler() { - @Override - public void doExecute(List param) { - sc.getSession2().setMiddlerResultHandler(null); - String sqls = null; - // 路由计算 - RouteResultset rrs = null; - try { - - sqls = buildSql(statement,sqlselect,param); - rrs = MycatServer - .getInstance() - .getRouterservice() - .route(MycatServer.getInstance().getConfig().getSystem(), - schema, sqlType,sqls.toLowerCase(), charset,sc ); - - } catch (Exception e) { - StringBuilder s = new StringBuilder(); - LOGGER.warn(s.append(this).append(sqls).toString() + " err:" + e.toString(),e); - String msg = e.getMessage(); - sc.writeErrMessage(ErrorCode.ER_PARSE_ERROR, msg == null ? e.getClass().getSimpleName() : msg); - return; - } - NonBlockingSession noBlockSession = new NonBlockingSession(sc.getSession2().getSource()); - noBlockSession.setMiddlerResultHandler(null); - //session的预编译标示传递 - noBlockSession.setPrepared(sc.getSession2().isPrepared()); - if (rrs != null) { - noBlockSession.setCanClose(false); - noBlockSession.execute(rrs, ServerParse.SELECT); - } - } - } ); - sc.getSession2().setMiddlerResultHandler(middlerResultHandler); - sc.getSession2().setCanClose(false); - - // 路由计算 - RouteResultset rrs = null; - try { - rrs = MycatServer - .getInstance() - .getRouterservice() - .route(MycatServer.getInstance().getConfig().getSystem(), - schema, ServerParse.SELECT, middlesql, charset, sc); - - } catch (Exception e) { - StringBuilder s = new StringBuilder(); - LOGGER.warn(s.append(this).append(middlesql).toString() + " err:" + e.toString(),e); - String msg = e.getMessage(); - sc.writeErrMessage(ErrorCode.ER_PARSE_ERROR, msg == null ? e.getClass().getSimpleName() : msg); - return null; - } - - if(rrs!=null){ - rrs.setCacheAble(false); - } - return rrs; - } - - /** - * 获取子查询执行结果后,改写原始sql 继续执行. - * @param statement - * @param sqlselect - * @param param - * @return - */ - private String buildSql(SQLStatement statement,SQLSelect sqlselect,List param){ - - SQLObject parent = sqlselect.getParent(); - RouteMiddlerReaultHandler handler = middlerResultHandler.get(parent.getClass()); - if(handler==null){ - throw new UnsupportedOperationException(parent.getClass()+" current is not supported "); - } - return handler.dohandler(statement, sqlselect, parent, param); - } - - /** - * 两个表的情况,catlet - * @param schema - * @param stmt - * @param charset - * @param sc - * @return - */ - private RouteResultset catletRoute(SchemaConfig schema,String stmt,String charset,ServerConnection sc){ - RouteResultset rrs = null; - try { - rrs = MycatServer - .getInstance() - .getRouterservice() - .route(MycatServer.getInstance().getConfig().getSystem(), - schema, ServerParse.SELECT, "/*!mycat:catlet=io.mycat.catlets.ShareJoin */ "+stmt, charset, sc); - - }catch(Exception e){ - - } - return rrs; - } - - /** - * 直接结果路由 - * @param rrs - * @param ctx - * @param schema - * @param druidParser - * @param statement - * @param cachePool - * @return - * @throws SQLNonTransientException - */ - private RouteResultset directRoute(RouteResultset rrs,DruidShardingParseInfo ctx,SchemaConfig schema, - DruidParser druidParser,SQLStatement statement,LayerCachePool cachePool) throws SQLNonTransientException{ - - //改写sql:如insert语句主键自增长, 在直接结果路由的情况下,进行sql 改写处理 - druidParser.changeSql(schema, rrs, statement,cachePool); - - /** - * DruidParser 解析过程中已完成了路由的直接返回 - */ - if ( rrs.isFinishedRoute() ) { - return rrs; - } - - /** - * 没有from的select语句或其他 - */ - if((ctx.getTables() == null || ctx.getTables().size() == 0)&&(ctx.getTableAliasMap()==null||ctx.getTableAliasMap().isEmpty())) - { - return RouterUtil.routeToSingleNode(rrs, schema.getRandomDataNode(), druidParser.getCtx().getSql()); - } - - if(druidParser.getCtx().getRouteCalculateUnits().size() == 0) { - RouteCalculateUnit routeCalculateUnit = new RouteCalculateUnit(); - druidParser.getCtx().addRouteCalculateUnit(routeCalculateUnit); - } - - SortedSet nodeSet = new TreeSet(); - boolean isAllGlobalTable = RouterUtil.isAllGlobalTable(ctx, schema); - for(RouteCalculateUnit unit: druidParser.getCtx().getRouteCalculateUnits()) { - RouteResultset rrsTmp = RouterUtil.tryRouteForTables(schema, druidParser.getCtx(), unit, rrs, isSelect(statement), cachePool); - if(rrsTmp != null&&rrsTmp.getNodes()!=null) { - for(RouteResultsetNode node :rrsTmp.getNodes()) { - nodeSet.add(node); - } - } - if(isAllGlobalTable) {//都是全局表时只计算一遍路由 - break; - } - } - - RouteResultsetNode[] nodes = new RouteResultsetNode[nodeSet.size()]; - int i = 0; - for (RouteResultsetNode aNodeSet : nodeSet) { - nodes[i] = aNodeSet; - if(statement instanceof MySqlInsertStatement &&ctx.getTables().size()==1&&schema.getTables().containsKey(ctx.getTables().get(0))) { - RuleConfig rule = schema.getTables().get(ctx.getTables().get(0)).getRule(); - if(rule!=null&& rule.getRuleAlgorithm() instanceof SlotFunction){ - aNodeSet.setStatement(ParseUtil.changeInsertAddSlot(aNodeSet.getStatement(),aNodeSet.getSlot())); - } - } - i++; - } - rrs.setNodes(nodes); - - //分表 - /** - * subTables="t_order$1-2,t_order3" - *目前分表 1.6 开始支持 幵丏 dataNode 在分表条件下只能配置一个,分表条件下不支持join。 - */ - if(rrs.isDistTable()){ - return this.routeDisTable(statement,rrs); - } - return rrs; - } - - private SQLExprTableSource getDisTable(SQLTableSource tableSource,RouteResultsetNode node) throws SQLSyntaxErrorException{ - if(node.getSubTableName()==null){ - String msg = " sub table not exists for " + node.getName() + " on " + tableSource; - LOGGER.error("DruidMycatRouteStrategyError " + msg); - throw new SQLSyntaxErrorException(msg); - } - - SQLIdentifierExpr sqlIdentifierExpr = new SQLIdentifierExpr(); - sqlIdentifierExpr.setParent(tableSource.getParent()); - sqlIdentifierExpr.setName(node.getSubTableName()); - SQLExprTableSource from2 = new SQLExprTableSource(sqlIdentifierExpr); - return from2; - } - - private RouteResultset routeDisTable(SQLStatement statement, RouteResultset rrs) throws SQLSyntaxErrorException{ - SQLTableSource tableSource = null; - if(statement instanceof SQLInsertStatement) { - SQLInsertStatement insertStatement = (SQLInsertStatement) statement; - tableSource = insertStatement.getTableSource(); - for (RouteResultsetNode node : rrs.getNodes()) { - SQLExprTableSource from2 = getDisTable(tableSource, node); - insertStatement.setTableSource(from2); - node.setStatement(insertStatement.toString()); - } - } - if(statement instanceof SQLDeleteStatement) { - SQLDeleteStatement deleteStatement = (SQLDeleteStatement) statement; - tableSource = deleteStatement.getTableSource(); - SQLTableSource from = deleteStatement.getFrom(); - for (RouteResultsetNode node : rrs.getNodes()) { - SQLExprTableSource from2 = getDisTable(tableSource, node); - - if (from == null) { - from2.setAlias(tableSource.toString()); - deleteStatement.setFrom(from2); - } else { - String alias = from.getAlias(); - from2.setAlias(alias); - deleteStatement.setFrom(from2); - } - - node.setStatement(deleteStatement.toString()); - } - } - if(statement instanceof SQLUpdateStatement) { - SQLUpdateStatement updateStatement = (SQLUpdateStatement) statement; - tableSource = updateStatement.getTableSource(); - - String alias = tableSource.getAlias(); - SQLExprTableSource exprSource = (SQLExprTableSource) tableSource; - SQLIdentifierExpr expr = (SQLIdentifierExpr) exprSource.getExpr(); - alias = alias == null ? expr.getName() : alias; - - for (RouteResultsetNode node : rrs.getNodes()) { - SQLExprTableSource from2 = getDisTable(tableSource, node); - from2.setAlias(alias); - updateStatement.setTableSource(from2); - node.setStatement(updateStatement.toString()); - } - } - - return rrs; - } - - /** - * SELECT 语句 - */ - private boolean isSelect(SQLStatement statement) { - if(statement instanceof SQLSelectStatement) { - return true; - } - return false; - } - - /** - * 检验不支持的SQLStatement类型 :不支持的类型直接抛SQLSyntaxErrorException异常 - * @param statement - * @throws SQLSyntaxErrorException - */ - private void checkUnSupportedStatement(SQLStatement statement) throws SQLSyntaxErrorException { - //不支持replace语句 - if(statement instanceof MySqlReplaceStatement) { - throw new SQLSyntaxErrorException(" ReplaceStatement can't be supported,use insert into ...on duplicate key update... instead "); - } - } - - /** - * 分析 SHOW SQL - */ - @Override - public RouteResultset analyseShowSQL(SchemaConfig schema, - RouteResultset rrs, String stmt) throws SQLSyntaxErrorException { - - String upStmt = stmt.toUpperCase(); - int tabInd = upStmt.indexOf(" TABLES"); - if (tabInd > 0) {// show tables - int[] nextPost = RouterUtil.getSpecPos(upStmt, 0); - if (nextPost[0] > 0) {// remove db info - int end = RouterUtil.getSpecEndPos(upStmt, tabInd); - if (upStmt.indexOf(" FULL") > 0) { - stmt = "SHOW FULL TABLES" + stmt.substring(end); - } else { - stmt = "SHOW TABLES" + stmt.substring(end); - } - } - String defaultNode= schema.getDataNode(); - if(!Strings.isNullOrEmpty(defaultNode)) - { - return RouterUtil.routeToSingleNode(rrs, defaultNode, stmt); - } - return RouterUtil.routeToMultiNode(false, rrs, schema.getMetaDataNodes(), stmt); - } - - /** - * show index or column - */ - int[] indx = RouterUtil.getSpecPos(upStmt, 0); - if (indx[0] > 0) { - /** - * has table - */ - int[] repPos = { indx[0] + indx[1], 0 }; - String tableName = RouterUtil.getShowTableName(stmt, repPos); - /** - * IN DB pattern - */ - int[] indx2 = RouterUtil.getSpecPos(upStmt, indx[0] + indx[1] + 1); - if (indx2[0] > 0) {// find LIKE OR WHERE - repPos[1] = RouterUtil.getSpecEndPos(upStmt, indx2[0] + indx2[1]); - - } - stmt = stmt.substring(0, indx[0]) + " FROM " + tableName + stmt.substring(repPos[1]); - RouterUtil.routeForTableMeta(rrs, schema, tableName, stmt); - return rrs; - - } - - /** - * show create table tableName - */ - int[] createTabInd = RouterUtil.getCreateTablePos(upStmt, 0); - if (createTabInd[0] > 0) { - int tableNameIndex = createTabInd[0] + createTabInd[1]; - if (upStmt.length() > tableNameIndex) { - String tableName = stmt.substring(tableNameIndex).trim(); - int ind2 = tableName.indexOf('.'); - if (ind2 > 0) { - tableName = tableName.substring(ind2 + 1); - } - RouterUtil.routeForTableMeta(rrs, schema, tableName, stmt); - return rrs; - } - } - - return RouterUtil.routeToSingleNode(rrs, schema.getRandomDataNode(), stmt); - } - - -// /** -// * 为一个表进行条件路由 -// * @param schema -// * @param tablesAndConditions -// * @param tablesRouteMap -// * @throws SQLNonTransientException -// */ -// private static RouteResultset findRouteWithcConditionsForOneTable(SchemaConfig schema, RouteResultset rrs, -// Map> conditions, String tableName, String sql) throws SQLNonTransientException { -// boolean cache = rrs.isCacheAble(); -// //为分库表找路由 -// tableName = tableName.toUpperCase(); -// TableConfig tableConfig = schema.getTables().get(tableName); -// //全局表或者不分库的表略过(全局表后面再计算) -// if(tableConfig.isGlobalTable()) { -// return null; -// } else {//非全局表 -// Set routeSet = new HashSet(); -// String joinKey = tableConfig.getJoinKey(); -// for(Map.Entry> condition : conditions.entrySet()) { -// String colName = condition.getKey(); -// //条件字段是拆分字段 -// if(colName.equals(tableConfig.getPartitionColumn())) { -// Set columnPairs = condition.getValue(); -// -// for(ColumnRoutePair pair : columnPairs) { -// if(pair.colValue != null) { -// Integer nodeIndex = tableConfig.getRule().getRuleAlgorithm().calculate(pair.colValue); -// if(nodeIndex == null) { -// String msg = "can't find any valid datanode :" + tableConfig.getName() -// + " -> " + tableConfig.getPartitionColumn() + " -> " + pair.colValue; -// LOGGER.warn(msg); -// throw new SQLNonTransientException(msg); -// } -// String node = tableConfig.getDataNodes().get(nodeIndex); -// if(node != null) {//找到一个路由节点 -// routeSet.add(node); -// } -// } -// if(pair.rangeValue != null) { -// Integer[] nodeIndexs = tableConfig.getRule().getRuleAlgorithm() -// .calculateRange(pair.rangeValue.beginValue.toString(), pair.rangeValue.endValue.toString()); -// for(Integer idx : nodeIndexs) { -// String node = tableConfig.getDataNodes().get(idx); -// if(node != null) {//找到一个路由节点 -// routeSet.add(node); -// } -// } -// } -// } -// } else if(joinKey != null && joinKey.equals(colName)) { -// Set dataNodeSet = RouterUtil.ruleCalculate( -// tableConfig.getParentTC(), condition.getValue()); -// if (dataNodeSet.isEmpty()) { -// throw new SQLNonTransientException( -// "parent key can't find any valid datanode "); -// } -// if (LOGGER.isDebugEnabled()) { -// LOGGER.debug("found partion nodes (using parent partion rule directly) for child table to update " -// + Arrays.toString(dataNodeSet.toArray()) + " sql :" + sql); -// } -// if (dataNodeSet.size() > 1) { -// return RouterUtil.routeToMultiNode(rrs.isCacheAble(), rrs, schema.getAllDataNodes(), sql); -// } else { -// rrs.setCacheAble(true); -// return RouterUtil.routeToSingleNode(rrs, dataNodeSet.iterator().next(), sql); -// } -// } else {//条件字段不是拆分字段也不是join字段,略过 -// continue; -// -// } -// } -// return RouterUtil.routeToMultiNode(cache, rrs, routeSet, sql); -// -// } -// -// } - - public RouteResultset routeSystemInfo(SchemaConfig schema, int sqlType, - String stmt, RouteResultset rrs) throws SQLSyntaxErrorException { - switch(sqlType){ - case ServerParse.SHOW:// if origSQL is like show tables - return analyseShowSQL(schema, rrs, stmt); - case ServerParse.SELECT://if origSQL is like select @@ - int index = stmt.indexOf("@@"); - if(index > 0 && "SELECT".equals(stmt.substring(0, index).trim().toUpperCase())){ - return analyseDoubleAtSgin(schema, rrs, stmt); - } - break; - case ServerParse.DESCRIBE:// if origSQL is meta SQL, such as describe table - int ind = stmt.indexOf(' '); - stmt = stmt.trim(); - return analyseDescrSQL(schema, rrs, stmt, ind + 1); - } - return null; - } - - /** - * 对Desc语句进行分析 返回数据路由集合 - * * - * @param schema 数据库名 - * @param rrs 数据路由集合 - * @param stmt 执行语句 - * @param ind 第一个' '的位置 - * @return RouteResultset (数据路由集合) - * @author mycat - */ - private static RouteResultset analyseDescrSQL(SchemaConfig schema, - RouteResultset rrs, String stmt, int ind) { - - final String MATCHED_FEATURE = "DESCRIBE "; - final String MATCHED2_FEATURE = "DESC "; - int pos = 0; - while (pos < stmt.length()) { - char ch = stmt.charAt(pos); - // 忽略处理注释 /* */ BEN - if(ch == '/' && pos+4 < stmt.length() && stmt.charAt(pos+1) == '*') { - if(stmt.substring(pos+2).indexOf("*/") != -1) { - pos += stmt.substring(pos+2).indexOf("*/")+4; - continue; - } else { - // 不应该发生这类情况。 - throw new IllegalArgumentException("sql 注释 语法错误"); - } - } else if(ch == 'D'||ch == 'd') { - // 匹配 [describe ] - if(pos+MATCHED_FEATURE.length() < stmt.length() && (stmt.substring(pos).toUpperCase().indexOf(MATCHED_FEATURE) != -1)) { - pos = pos + MATCHED_FEATURE.length(); - break; - } else if(pos+MATCHED2_FEATURE.length() < stmt.length() && (stmt.substring(pos).toUpperCase().indexOf(MATCHED2_FEATURE) != -1)) { - pos = pos + MATCHED2_FEATURE.length(); - break; - } else { - pos++; - } - } - } - - // 重置ind坐标。BEN GONG - ind = pos; - int[] repPos = { ind, 0 }; - String tableName = RouterUtil.getTableName(stmt, repPos); - - stmt = stmt.substring(0, ind) + tableName + stmt.substring(repPos[1]); - RouterUtil.routeForTableMeta(rrs, schema, tableName, stmt); - return rrs; - } - - /** - * 根据执行语句判断数据路由 - * - * @param schema 数据库名 - * @param rrs 数据路由集合 - * @param stmt 执行sql - * @return RouteResultset 数据路由集合 - * @throws SQLSyntaxErrorException - * @author mycat - */ - private RouteResultset analyseDoubleAtSgin(SchemaConfig schema, - RouteResultset rrs, String stmt) throws SQLSyntaxErrorException { - String upStmt = stmt.toUpperCase(); - int atSginInd = upStmt.indexOf(" @@"); - if (atSginInd > 0) { - return RouterUtil.routeToMultiNode(false, rrs, schema.getMetaDataNodes(), stmt); - } - return RouterUtil.routeToSingleNode(rrs, schema.getRandomDataNode(), stmt); - } +package io.mycat.route.impl; + +import java.sql.SQLNonTransientException; +import java.sql.SQLSyntaxErrorException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.alibaba.druid.sql.SQLUtils; +import com.alibaba.druid.sql.ast.SQLObject; +import com.alibaba.druid.sql.ast.SQLStatement; +import com.alibaba.druid.sql.ast.expr.SQLAllExpr; +import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr; +import com.alibaba.druid.sql.ast.expr.SQLExistsExpr; +import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr; +import com.alibaba.druid.sql.ast.expr.SQLInSubQueryExpr; +import com.alibaba.druid.sql.ast.expr.SQLQueryExpr; +import com.alibaba.druid.sql.ast.statement.SQLDeleteStatement; +import com.alibaba.druid.sql.ast.statement.SQLExprTableSource; +import com.alibaba.druid.sql.ast.statement.SQLInsertStatement; +import com.alibaba.druid.sql.ast.statement.SQLSelect; +import com.alibaba.druid.sql.ast.statement.SQLSelectQuery; +import com.alibaba.druid.sql.ast.statement.SQLSelectStatement; +import com.alibaba.druid.sql.ast.statement.SQLTableSource; +import com.alibaba.druid.sql.ast.statement.SQLUpdateStatement; +import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement; +import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlReplaceStatement; +import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock; +import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser; +import com.alibaba.druid.sql.parser.SQLStatementParser; +import com.alibaba.druid.stat.TableStat.Relationship; +import com.google.common.base.Strings; + +import io.mycat.MycatServer; +import io.mycat.backend.mysql.nio.handler.MiddlerQueryResultHandler; +import io.mycat.backend.mysql.nio.handler.MiddlerResultHandler; +import io.mycat.backend.mysql.nio.handler.SecondHandler; +import io.mycat.cache.LayerCachePool; +import io.mycat.config.ErrorCode; +import io.mycat.config.model.SchemaConfig; +import io.mycat.config.model.TableConfig; +import io.mycat.config.model.rule.RuleConfig; +import io.mycat.route.RouteResultset; +import io.mycat.route.RouteResultsetNode; +import io.mycat.route.function.SlotFunction; +import io.mycat.route.impl.middlerResultStrategy.BinaryOpResultHandler; +import io.mycat.route.impl.middlerResultStrategy.InSubQueryResultHandler; +import io.mycat.route.impl.middlerResultStrategy.RouteMiddlerReaultHandler; +import io.mycat.route.impl.middlerResultStrategy.SQLAllResultHandler; +import io.mycat.route.impl.middlerResultStrategy.SQLExistsResultHandler; +import io.mycat.route.impl.middlerResultStrategy.SQLQueryResultHandler; +import io.mycat.route.parser.druid.DruidParser; +import io.mycat.route.parser.druid.DruidParserFactory; +import io.mycat.route.parser.druid.DruidShardingParseInfo; +import io.mycat.route.parser.druid.MycatSchemaStatVisitor; +import io.mycat.route.parser.druid.MycatStatementParser; +import io.mycat.route.parser.druid.RouteCalculateUnit; +import io.mycat.route.parser.util.ParseUtil; +import io.mycat.route.util.RouterUtil; +import io.mycat.server.NonBlockingSession; +import io.mycat.server.ServerConnection; +import io.mycat.server.parser.ServerParse; + +public class DruidMycatRouteStrategy extends AbstractRouteStrategy { + + public static final Logger LOGGER = LoggerFactory.getLogger(DruidMycatRouteStrategy.class); + + private static Map,RouteMiddlerReaultHandler> middlerResultHandler = new HashMap<>(); + + static{ + middlerResultHandler.put(SQLQueryExpr.class, new SQLQueryResultHandler()); + middlerResultHandler.put(SQLBinaryOpExpr.class, new BinaryOpResultHandler()); + middlerResultHandler.put(SQLInSubQueryExpr.class, new InSubQueryResultHandler()); + middlerResultHandler.put(SQLExistsExpr.class, new SQLExistsResultHandler()); + middlerResultHandler.put(SQLAllExpr.class, new SQLAllResultHandler()); + } + + + @Override + public RouteResultset routeNormalSqlWithAST(SchemaConfig schema, + String stmt, RouteResultset rrs,String charset, + LayerCachePool cachePool,int sqlType,ServerConnection sc) throws SQLNonTransientException { + + /** + * 只有mysql时只支持mysql语法 + */ + SQLStatementParser parser = null; + if (schema.isNeedSupportMultiDBType()) { + parser = new MycatStatementParser(stmt); + } else { + parser = new MySqlStatementParser(stmt); + } + + MycatSchemaStatVisitor visitor = null; + SQLStatement statement; + + /** + * 解析出现问题统一抛SQL语法错误 + */ + try { + statement = parser.parseStatement(); + visitor = new MycatSchemaStatVisitor(); + } catch (Exception t) { + LOGGER.error("DruidMycatRouteStrategyError", t); + throw new SQLSyntaxErrorException(t); + } + + /** + * 检验unsupported statement + */ + checkUnSupportedStatement(statement); + + DruidParser druidParser = DruidParserFactory.create(schema, statement, visitor); + druidParser.parser(schema, rrs, statement, stmt,cachePool,visitor); + DruidShardingParseInfo ctx= druidParser.getCtx() ; + rrs.setTables(ctx.getTables()); + + if(visitor.isSubqueryRelationOr()){ + String err = "In subQuery,the or condition is not supported."; + LOGGER.error(err); + throw new SQLSyntaxErrorException(err); + } + + /* 按照以下情况路由 + 1.2.1 可以直接路由. + 1.2.2 两个表夸库join的sql.调用calat + 1.2.3 需要先执行subquery 的sql.把subquery拆分出来.获取结果后,与outerquery + */ + + //add huangyiming 分片规则不一样的且表中带查询条件的则走Catlet + List tables = ctx.getTables(); + SchemaConfig schemaConf = MycatServer.getInstance().getConfig().getSchemas().get(schema.getName()); + int index = 0; + RuleConfig firstRule = null; + boolean directRoute = true; + Set firstDataNodes = new HashSet(); + Map tconfigs = schemaConf==null?null:schemaConf.getTables(); + + Map rulemap = new HashMap<>(); + if(tconfigs!=null){ + for(String tableName : tables){ + TableConfig tc = tconfigs.get(tableName); + if(tc == null){ + //add 别名中取 + Map tableAliasMap = ctx.getTableAliasMap(); + if(tableAliasMap !=null && tableAliasMap.get(tableName) !=null){ + tc = schemaConf.getTables().get(tableAliasMap.get(tableName)); + } + } + + if(index == 0){ + if(tc !=null){ + firstRule= tc.getRule(); + //没有指定分片规则时,不做处理 + if(firstRule==null){ + continue; + } + firstDataNodes.addAll(tc.getDataNodes()); + rulemap.put(tc.getName(), firstRule); + } + }else{ + if(tc !=null){ + //ER关系表的时候是可能存在字表中没有tablerule的情况,所以加上判断 + RuleConfig ruleCfg = tc.getRule(); + if(ruleCfg==null){ //没有指定分片规则时,不做处理 + continue; + } + Set dataNodes = new HashSet(); + dataNodes.addAll(tc.getDataNodes()); + rulemap.put(tc.getName(), ruleCfg); + //如果匹配规则不相同或者分片的datanode不相同则需要走子查询处理 + if(firstRule!=null&&((ruleCfg !=null && !ruleCfg.getRuleAlgorithm().equals(firstRule.getRuleAlgorithm()) )||( !dataNodes.equals(firstDataNodes)))){ + directRoute = false; + break; + } + } + } + index++; + } + } + + RouteResultset rrsResult = rrs; + if(directRoute){ //直接路由 + if(!RouterUtil.isAllGlobalTable(ctx, schemaConf)){ + if(rulemap.size()>1&&!checkRuleField(rulemap,visitor)){ + String err = "In case of slice table,there is no rule field in the relationship condition!"; + LOGGER.error(err); + throw new SQLSyntaxErrorException(err); + } + } + rrsResult = directRoute(rrs,ctx,schema,druidParser,statement,cachePool); + }else{ + int subQuerySize = visitor.getSubQuerys().size(); + if(subQuerySize==0&&ctx.getTables().size()==2){ //两表关联,考虑使用catlet + if(!visitor.getRelationships().isEmpty()){ + rrs.setCacheAble(false); + rrs.setFinishedRoute(true); + rrsResult = catletRoute(schema,ctx.getSql(),charset,sc); + }else{ + rrsResult = directRoute(rrs,ctx,schema,druidParser,statement,cachePool); + } + }else if(subQuerySize==1){ //只涉及一张表的子查询,使用 MiddlerResultHandler 获取中间结果后,改写原有 sql 继续执行 TODO 后期可能会考虑多个子查询的情况. + SQLSelect sqlselect = visitor.getSubQuerys().iterator().next(); + if(!visitor.getRelationships().isEmpty()){ // 当 inner query 和 outer query 有关联条件时,暂不支持 + String err = "In case of slice table,sql have different rules,the relationship condition is not supported."; + LOGGER.error(err); + throw new SQLSyntaxErrorException(err); + }else{ + SQLSelectQuery sqlSelectQuery = sqlselect.getQuery(); + if(((MySqlSelectQueryBlock)sqlSelectQuery).getFrom() instanceof SQLExprTableSource) { + rrs.setCacheAble(false); + rrs.setFinishedRoute(true); + rrsResult = middlerResultRoute(schema,charset,sqlselect,sqlType,statement,sc); + } + } + }else if(subQuerySize >=2){ + String err = "In case of slice table,sql has different rules,currently only one subQuery is supported."; + LOGGER.error(err); + throw new SQLSyntaxErrorException(err); + } + } + return rrsResult; + } + + /** + * 子查询中存在关联查询的情况下,检查关联字段是否是分片字段 + * @param rulemap + * @param ships + * @return + */ + private boolean checkRuleField(Map rulemap,MycatSchemaStatVisitor visitor){ + + if(!MycatServer.getInstance().getConfig().getSystem().isSubqueryRelationshipCheck()){ + return true; + } + + Set ships = visitor.getRelationships(); + Iterator iter = ships.iterator(); + while(iter.hasNext()){ + Relationship ship = iter.next(); + String lefttable = ship.getLeft().getTable().toUpperCase(); + String righttable = ship.getRight().getTable().toUpperCase(); + // 如果是同一个表中的关联条件,不做处理 + if(lefttable.equals(righttable)){ + return true; + } + RuleConfig leftconfig = rulemap.get(lefttable); + RuleConfig rightconfig = rulemap.get(righttable); + + if(null!=leftconfig&&null!=rightconfig + &&leftconfig.equals(rightconfig) + &&leftconfig.getColumn().equals(ship.getLeft().getName().toUpperCase()) + &&rightconfig.getColumn().equals(ship.getRight().getName().toUpperCase())){ + return true; + } + } + return false; + } + + private RouteResultset middlerResultRoute(final SchemaConfig schema,final String charset,final SQLSelect sqlselect, + final int sqlType,final SQLStatement statement,final ServerConnection sc){ + + final String middlesql = SQLUtils.toMySqlString(sqlselect); + + MiddlerResultHandler middlerResultHandler = new MiddlerQueryResultHandler<>(new SecondHandler() { + @Override + public void doExecute(List param) { + sc.getSession2().setMiddlerResultHandler(null); + String sqls = null; + // 路由计算 + RouteResultset rrs = null; + try { + + sqls = buildSql(statement,sqlselect,param); + rrs = MycatServer + .getInstance() + .getRouterservice() + .route(MycatServer.getInstance().getConfig().getSystem(), + schema, sqlType,sqls.toLowerCase(), charset,sc ); + + } catch (Exception e) { + StringBuilder s = new StringBuilder(); + LOGGER.warn(s.append(this).append(sqls).toString() + " err:" + e.toString(),e); + String msg = e.getMessage(); + sc.writeErrMessage(ErrorCode.ER_PARSE_ERROR, msg == null ? e.getClass().getSimpleName() : msg); + return; + } + NonBlockingSession noBlockSession = new NonBlockingSession(sc.getSession2().getSource()); + noBlockSession.setMiddlerResultHandler(null); + //session的预编译标示传递 + noBlockSession.setPrepared(sc.getSession2().isPrepared()); + if (rrs != null) { + noBlockSession.setCanClose(false); + noBlockSession.execute(rrs, ServerParse.SELECT); + } + } + } ); + sc.getSession2().setMiddlerResultHandler(middlerResultHandler); + sc.getSession2().setCanClose(false); + + // 路由计算 + RouteResultset rrs = null; + try { + rrs = MycatServer + .getInstance() + .getRouterservice() + .route(MycatServer.getInstance().getConfig().getSystem(), + schema, ServerParse.SELECT, middlesql, charset, sc); + + } catch (Exception e) { + StringBuilder s = new StringBuilder(); + LOGGER.warn(s.append(this).append(middlesql).toString() + " err:" + e.toString(),e); + String msg = e.getMessage(); + sc.writeErrMessage(ErrorCode.ER_PARSE_ERROR, msg == null ? e.getClass().getSimpleName() : msg); + return null; + } + + if(rrs!=null){ + rrs.setCacheAble(false); + } + return rrs; + } + + /** + * 获取子查询执行结果后,改写原始sql 继续执行. + * @param statement + * @param sqlselect + * @param param + * @return + */ + private String buildSql(SQLStatement statement,SQLSelect sqlselect,List param){ + + SQLObject parent = sqlselect.getParent(); + RouteMiddlerReaultHandler handler = middlerResultHandler.get(parent.getClass()); + if(handler==null){ + throw new UnsupportedOperationException(parent.getClass()+" current is not supported "); + } + return handler.dohandler(statement, sqlselect, parent, param); + } + + /** + * 两个表的情况,catlet + * @param schema + * @param stmt + * @param charset + * @param sc + * @return + */ + private RouteResultset catletRoute(SchemaConfig schema,String stmt,String charset,ServerConnection sc){ + RouteResultset rrs = null; + try { + rrs = MycatServer + .getInstance() + .getRouterservice() + .route(MycatServer.getInstance().getConfig().getSystem(), + schema, ServerParse.SELECT, "/*!mycat:catlet=io.mycat.catlets.ShareJoin */ "+stmt, charset, sc); + + }catch(Exception e){ + + } + return rrs; + } + + /** + * 直接结果路由 + * @param rrs + * @param ctx + * @param schema + * @param druidParser + * @param statement + * @param cachePool + * @return + * @throws SQLNonTransientException + */ + private RouteResultset directRoute(RouteResultset rrs,DruidShardingParseInfo ctx,SchemaConfig schema, + DruidParser druidParser,SQLStatement statement,LayerCachePool cachePool) throws SQLNonTransientException{ + + //改写sql:如insert语句主键自增长, 在直接结果路由的情况下,进行sql 改写处理 + druidParser.changeSql(schema, rrs, statement,cachePool); + + /** + * DruidParser 解析过程中已完成了路由的直接返回 + */ + if ( rrs.isFinishedRoute() ) { + return rrs; + } + + /** + * 没有from的select语句或其他 + */ + if((ctx.getTables() == null || ctx.getTables().size() == 0)&&(ctx.getTableAliasMap()==null||ctx.getTableAliasMap().isEmpty())) + { + return RouterUtil.routeToSingleNode(rrs, schema.getRandomDataNode(), druidParser.getCtx().getSql()); + } + + if(druidParser.getCtx().getRouteCalculateUnits().size() == 0) { + RouteCalculateUnit routeCalculateUnit = new RouteCalculateUnit(); + druidParser.getCtx().addRouteCalculateUnit(routeCalculateUnit); + } + + SortedSet nodeSet = new TreeSet(); + boolean isAllGlobalTable = RouterUtil.isAllGlobalTable(ctx, schema); + for(RouteCalculateUnit unit: druidParser.getCtx().getRouteCalculateUnits()) { + RouteResultset rrsTmp = RouterUtil.tryRouteForTables(schema, druidParser.getCtx(), unit, rrs, isSelect(statement), cachePool); + if(rrsTmp != null&&rrsTmp.getNodes()!=null) { + for(RouteResultsetNode node :rrsTmp.getNodes()) { + nodeSet.add(node); + } + } + if(isAllGlobalTable) {//都是全局表时只计算一遍路由 + break; + } + } + + RouteResultsetNode[] nodes = new RouteResultsetNode[nodeSet.size()]; + int i = 0; + for (RouteResultsetNode aNodeSet : nodeSet) { + nodes[i] = aNodeSet; + if(statement instanceof MySqlInsertStatement &&ctx.getTables().size()==1&&schema.getTables().containsKey(ctx.getTables().get(0))) { + RuleConfig rule = schema.getTables().get(ctx.getTables().get(0)).getRule(); + if(rule!=null&& rule.getRuleAlgorithm() instanceof SlotFunction){ + aNodeSet.setStatement(ParseUtil.changeInsertAddSlot(aNodeSet.getStatement(),aNodeSet.getSlot())); + } + } + i++; + } + rrs.setNodes(nodes); + + //分表 + /** + * subTables="t_order$1-2,t_order3" + *目前分表 1.6 开始支持 幵丏 dataNode 在分表条件下只能配置一个,分表条件下不支持join。 + */ + if(rrs.isDistTable()){ + return this.routeDisTable(statement,rrs); + } + return rrs; + } + + private SQLExprTableSource getDisTable(SQLTableSource tableSource,RouteResultsetNode node) throws SQLSyntaxErrorException{ + if(node.getSubTableName()==null){ + String msg = " sub table not exists for " + node.getName() + " on " + tableSource; + LOGGER.error("DruidMycatRouteStrategyError " + msg); + throw new SQLSyntaxErrorException(msg); + } + + SQLIdentifierExpr sqlIdentifierExpr = new SQLIdentifierExpr(); + sqlIdentifierExpr.setParent(tableSource.getParent()); + sqlIdentifierExpr.setName(node.getSubTableName()); + SQLExprTableSource from2 = new SQLExprTableSource(sqlIdentifierExpr); + return from2; + } + + private RouteResultset routeDisTable(SQLStatement statement, RouteResultset rrs) throws SQLSyntaxErrorException{ + SQLTableSource tableSource = null; + if(statement instanceof SQLInsertStatement) { + SQLInsertStatement insertStatement = (SQLInsertStatement) statement; + tableSource = insertStatement.getTableSource(); + for (RouteResultsetNode node : rrs.getNodes()) { + SQLExprTableSource from2 = getDisTable(tableSource, node); + insertStatement.setTableSource(from2); + node.setStatement(insertStatement.toString()); + } + } + if(statement instanceof SQLDeleteStatement) { + SQLDeleteStatement deleteStatement = (SQLDeleteStatement) statement; + tableSource = deleteStatement.getTableSource(); + for (RouteResultsetNode node : rrs.getNodes()) { + SQLExprTableSource from2 = getDisTable(tableSource, node); + deleteStatement.setTableSource(from2); + node.setStatement(deleteStatement.toString()); + } + } + if(statement instanceof SQLUpdateStatement) { + SQLUpdateStatement updateStatement = (SQLUpdateStatement) statement; + tableSource = updateStatement.getTableSource(); + for (RouteResultsetNode node : rrs.getNodes()) { + SQLExprTableSource from2 = getDisTable(tableSource, node); + updateStatement.setTableSource(from2); + node.setStatement(updateStatement.toString()); + } + } + + return rrs; + } + + /** + * SELECT 语句 + */ + private boolean isSelect(SQLStatement statement) { + if(statement instanceof SQLSelectStatement) { + return true; + } + return false; + } + + /** + * 检验不支持的SQLStatement类型 :不支持的类型直接抛SQLSyntaxErrorException异常 + * @param statement + * @throws SQLSyntaxErrorException + */ + private void checkUnSupportedStatement(SQLStatement statement) throws SQLSyntaxErrorException { + //不支持replace语句 + if(statement instanceof MySqlReplaceStatement) { + throw new SQLSyntaxErrorException(" ReplaceStatement can't be supported,use insert into ...on duplicate key update... instead "); + } + } + + /** + * 分析 SHOW SQL + */ + @Override + public RouteResultset analyseShowSQL(SchemaConfig schema, + RouteResultset rrs, String stmt) throws SQLSyntaxErrorException { + + String upStmt = stmt.toUpperCase(); + int tabInd = upStmt.indexOf(" TABLES"); + if (tabInd > 0) {// show tables + int[] nextPost = RouterUtil.getSpecPos(upStmt, 0); + if (nextPost[0] > 0) {// remove db info + int end = RouterUtil.getSpecEndPos(upStmt, tabInd); + if (upStmt.indexOf(" FULL") > 0) { + stmt = "SHOW FULL TABLES" + stmt.substring(end); + } else { + stmt = "SHOW TABLES" + stmt.substring(end); + } + } + String defaultNode= schema.getDataNode(); + if(!Strings.isNullOrEmpty(defaultNode)) + { + return RouterUtil.routeToSingleNode(rrs, defaultNode, stmt); + } + return RouterUtil.routeToMultiNode(false, rrs, schema.getMetaDataNodes(), stmt); + } + + /** + * show index or column + */ + int[] indx = RouterUtil.getSpecPos(upStmt, 0); + if (indx[0] > 0) { + /** + * has table + */ + int[] repPos = { indx[0] + indx[1], 0 }; + String tableName = RouterUtil.getShowTableName(stmt, repPos); + /** + * IN DB pattern + */ + int[] indx2 = RouterUtil.getSpecPos(upStmt, indx[0] + indx[1] + 1); + if (indx2[0] > 0) {// find LIKE OR WHERE + repPos[1] = RouterUtil.getSpecEndPos(upStmt, indx2[0] + indx2[1]); + + } + stmt = stmt.substring(0, indx[0]) + " FROM " + tableName + stmt.substring(repPos[1]); + RouterUtil.routeForTableMeta(rrs, schema, tableName, stmt); + return rrs; + + } + + /** + * show create table tableName + */ + int[] createTabInd = RouterUtil.getCreateTablePos(upStmt, 0); + if (createTabInd[0] > 0) { + int tableNameIndex = createTabInd[0] + createTabInd[1]; + if (upStmt.length() > tableNameIndex) { + String tableName = stmt.substring(tableNameIndex).trim(); + int ind2 = tableName.indexOf('.'); + if (ind2 > 0) { + tableName = tableName.substring(ind2 + 1); + } + RouterUtil.routeForTableMeta(rrs, schema, tableName, stmt); + return rrs; + } + } + + return RouterUtil.routeToSingleNode(rrs, schema.getRandomDataNode(), stmt); + } + + +// /** +// * 为一个表进行条件路由 +// * @param schema +// * @param tablesAndConditions +// * @param tablesRouteMap +// * @throws SQLNonTransientException +// */ +// private static RouteResultset findRouteWithcConditionsForOneTable(SchemaConfig schema, RouteResultset rrs, +// Map> conditions, String tableName, String sql) throws SQLNonTransientException { +// boolean cache = rrs.isCacheAble(); +// //为分库表找路由 +// tableName = tableName.toUpperCase(); +// TableConfig tableConfig = schema.getTables().get(tableName); +// //全局表或者不分库的表略过(全局表后面再计算) +// if(tableConfig.isGlobalTable()) { +// return null; +// } else {//非全局表 +// Set routeSet = new HashSet(); +// String joinKey = tableConfig.getJoinKey(); +// for(Map.Entry> condition : conditions.entrySet()) { +// String colName = condition.getKey(); +// //条件字段是拆分字段 +// if(colName.equals(tableConfig.getPartitionColumn())) { +// Set columnPairs = condition.getValue(); +// +// for(ColumnRoutePair pair : columnPairs) { +// if(pair.colValue != null) { +// Integer nodeIndex = tableConfig.getRule().getRuleAlgorithm().calculate(pair.colValue); +// if(nodeIndex == null) { +// String msg = "can't find any valid datanode :" + tableConfig.getName() +// + " -> " + tableConfig.getPartitionColumn() + " -> " + pair.colValue; +// LOGGER.warn(msg); +// throw new SQLNonTransientException(msg); +// } +// String node = tableConfig.getDataNodes().get(nodeIndex); +// if(node != null) {//找到一个路由节点 +// routeSet.add(node); +// } +// } +// if(pair.rangeValue != null) { +// Integer[] nodeIndexs = tableConfig.getRule().getRuleAlgorithm() +// .calculateRange(pair.rangeValue.beginValue.toString(), pair.rangeValue.endValue.toString()); +// for(Integer idx : nodeIndexs) { +// String node = tableConfig.getDataNodes().get(idx); +// if(node != null) {//找到一个路由节点 +// routeSet.add(node); +// } +// } +// } +// } +// } else if(joinKey != null && joinKey.equals(colName)) { +// Set dataNodeSet = RouterUtil.ruleCalculate( +// tableConfig.getParentTC(), condition.getValue()); +// if (dataNodeSet.isEmpty()) { +// throw new SQLNonTransientException( +// "parent key can't find any valid datanode "); +// } +// if (LOGGER.isDebugEnabled()) { +// LOGGER.debug("found partion nodes (using parent partion rule directly) for child table to update " +// + Arrays.toString(dataNodeSet.toArray()) + " sql :" + sql); +// } +// if (dataNodeSet.size() > 1) { +// return RouterUtil.routeToMultiNode(rrs.isCacheAble(), rrs, schema.getAllDataNodes(), sql); +// } else { +// rrs.setCacheAble(true); +// return RouterUtil.routeToSingleNode(rrs, dataNodeSet.iterator().next(), sql); +// } +// } else {//条件字段不是拆分字段也不是join字段,略过 +// continue; +// +// } +// } +// return RouterUtil.routeToMultiNode(cache, rrs, routeSet, sql); +// +// } +// +// } + + public RouteResultset routeSystemInfo(SchemaConfig schema, int sqlType, + String stmt, RouteResultset rrs) throws SQLSyntaxErrorException { + switch(sqlType){ + case ServerParse.SHOW:// if origSQL is like show tables + return analyseShowSQL(schema, rrs, stmt); + case ServerParse.SELECT://if origSQL is like select @@ + int index = stmt.indexOf("@@"); + if(index > 0 && "SELECT".equals(stmt.substring(0, index).trim().toUpperCase())){ + return analyseDoubleAtSgin(schema, rrs, stmt); + } + break; + case ServerParse.DESCRIBE:// if origSQL is meta SQL, such as describe table + int ind = stmt.indexOf(' '); + stmt = stmt.trim(); + return analyseDescrSQL(schema, rrs, stmt, ind + 1); + } + return null; + } + + /** + * 对Desc语句进行分析 返回数据路由集合 + * * + * @param schema 数据库名 + * @param rrs 数据路由集合 + * @param stmt 执行语句 + * @param ind 第一个' '的位置 + * @return RouteResultset (数据路由集合) + * @author mycat + */ + private static RouteResultset analyseDescrSQL(SchemaConfig schema, + RouteResultset rrs, String stmt, int ind) { + + final String MATCHED_FEATURE = "DESCRIBE "; + final String MATCHED2_FEATURE = "DESC "; + int pos = 0; + while (pos < stmt.length()) { + char ch = stmt.charAt(pos); + // 忽略处理注释 /* */ BEN + if(ch == '/' && pos+4 < stmt.length() && stmt.charAt(pos+1) == '*') { + if(stmt.substring(pos+2).indexOf("*/") != -1) { + pos += stmt.substring(pos+2).indexOf("*/")+4; + continue; + } else { + // 不应该发生这类情况。 + throw new IllegalArgumentException("sql 注释 语法错误"); + } + } else if(ch == 'D'||ch == 'd') { + // 匹配 [describe ] + if(pos+MATCHED_FEATURE.length() < stmt.length() && (stmt.substring(pos).toUpperCase().indexOf(MATCHED_FEATURE) != -1)) { + pos = pos + MATCHED_FEATURE.length(); + break; + } else if(pos+MATCHED2_FEATURE.length() < stmt.length() && (stmt.substring(pos).toUpperCase().indexOf(MATCHED2_FEATURE) != -1)) { + pos = pos + MATCHED2_FEATURE.length(); + break; + } else { + pos++; + } + } + } + + // 重置ind坐标。BEN GONG + ind = pos; + int[] repPos = { ind, 0 }; + String tableName = RouterUtil.getTableName(stmt, repPos); + + stmt = stmt.substring(0, ind) + tableName + stmt.substring(repPos[1]); + RouterUtil.routeForTableMeta(rrs, schema, tableName, stmt); + return rrs; + } + + /** + * 根据执行语句判断数据路由 + * + * @param schema 数据库名 + * @param rrs 数据路由集合 + * @param stmt 执行sql + * @return RouteResultset 数据路由集合 + * @throws SQLSyntaxErrorException + * @author mycat + */ + private RouteResultset analyseDoubleAtSgin(SchemaConfig schema, + RouteResultset rrs, String stmt) throws SQLSyntaxErrorException { + String upStmt = stmt.toUpperCase(); + int atSginInd = upStmt.indexOf(" @@"); + if (atSginInd > 0) { + return RouterUtil.routeToMultiNode(false, rrs, schema.getMetaDataNodes(), stmt); + } + return RouterUtil.routeToSingleNode(rrs, schema.getRandomDataNode(), stmt); + } } \ No newline at end of file diff --git a/src/main/java/io/mycat/route/parser/druid/impl/DruidSelectParser.java b/src/main/java/io/mycat/route/parser/druid/impl/DruidSelectParser.java index 1294d7f94..66770a8db 100644 --- a/src/main/java/io/mycat/route/parser/druid/impl/DruidSelectParser.java +++ b/src/main/java/io/mycat/route/parser/druid/impl/DruidSelectParser.java @@ -1,861 +1,764 @@ -package io.mycat.route.parser.druid.impl; - -import java.sql.SQLNonTransientException; -import java.sql.SQLSyntaxErrorException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.SortedSet; -import java.util.TreeSet; - -import org.apache.logging.log4j.util.Strings; - -import com.alibaba.druid.sql.SQLUtils; -import com.alibaba.druid.sql.ast.SQLExpr; -import com.alibaba.druid.sql.ast.SQLName; -import com.alibaba.druid.sql.ast.SQLOrderingSpecification; -import com.alibaba.druid.sql.ast.SQLStatement; -import com.alibaba.druid.sql.ast.expr.SQLAggregateExpr; -import com.alibaba.druid.sql.ast.expr.SQLAllColumnExpr; -import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr; -import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator; -import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr; -import com.alibaba.druid.sql.ast.expr.SQLIntegerExpr; -import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr; -import com.alibaba.druid.sql.ast.expr.SQLNumericLiteralExpr; -import com.alibaba.druid.sql.ast.expr.SQLPropertyExpr; -import com.alibaba.druid.sql.ast.expr.SQLTextLiteralExpr; -import com.alibaba.druid.sql.ast.statement.SQLExprTableSource; -import com.alibaba.druid.sql.ast.statement.SQLJoinTableSource; -import com.alibaba.druid.sql.ast.statement.SQLSelectGroupByClause; -import com.alibaba.druid.sql.ast.statement.SQLSelectItem; -import com.alibaba.druid.sql.ast.statement.SQLSelectOrderByItem; -import com.alibaba.druid.sql.ast.statement.SQLSelectQuery; -import com.alibaba.druid.sql.ast.statement.SQLSelectQueryBlock; -import com.alibaba.druid.sql.ast.statement.SQLSelectStatement; -import com.alibaba.druid.sql.ast.statement.SQLSubqueryTableSource; -import com.alibaba.druid.sql.ast.statement.SQLTableSource; -import com.alibaba.druid.sql.dialect.db2.ast.stmt.DB2SelectQueryBlock; -import com.alibaba.druid.sql.dialect.db2.visitor.DB2OutputVisitor; -import com.alibaba.druid.sql.dialect.mysql.ast.expr.MySqlOrderingExpr; -import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock; -import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock.Limit; -import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUnionQuery; -import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser; -import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlOutputVisitor; -import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor; -import com.alibaba.druid.sql.dialect.oracle.ast.stmt.OracleSelectQueryBlock; -import com.alibaba.druid.sql.dialect.oracle.visitor.OracleOutputVisitor; -import com.alibaba.druid.sql.dialect.postgresql.ast.stmt.PGSelectQueryBlock; -import com.alibaba.druid.sql.dialect.postgresql.visitor.PGOutputVisitor; -import com.alibaba.druid.sql.dialect.sqlserver.ast.SQLServerSelectQueryBlock; -import com.alibaba.druid.sql.parser.SQLStatementParser; -import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor; -import com.alibaba.druid.util.JdbcConstants; -import com.alibaba.druid.wall.spi.WallVisitorUtils; - -import io.mycat.MycatServer; -import io.mycat.cache.LayerCachePool; -import io.mycat.config.ErrorCode; -import io.mycat.config.model.SchemaConfig; -import io.mycat.config.model.TableConfig; -import io.mycat.route.RouteResultset; -import io.mycat.route.RouteResultsetNode; -import io.mycat.route.parser.druid.MycatSchemaStatVisitor; -import io.mycat.route.parser.druid.MycatStatementParser; -import io.mycat.route.parser.druid.RouteCalculateUnit; -import io.mycat.route.parser.util.WildcardUtil; -import io.mycat.route.util.RouterUtil; -import io.mycat.sqlengine.mpp.ColumnRoutePair; -import io.mycat.sqlengine.mpp.HavingCols; -import io.mycat.sqlengine.mpp.MergeCol; -import io.mycat.sqlengine.mpp.OrderCol; -import io.mycat.util.ObjectUtil; -import io.mycat.util.StringUtil; - -public class DruidSelectParser extends DefaultDruidParser { - - protected boolean isNeedParseOrderAgg=true; - - @Override - public void statementParse(SchemaConfig schema, RouteResultset rrs, SQLStatement stmt) { - SQLSelectStatement selectStmt = (SQLSelectStatement)stmt; - SQLSelectQuery sqlSelectQuery = selectStmt.getSelect().getQuery(); - if(sqlSelectQuery instanceof MySqlSelectQueryBlock) { - MySqlSelectQueryBlock mysqlSelectQuery = (MySqlSelectQueryBlock)selectStmt.getSelect().getQuery(); - - parseOrderAggGroupMysql(schema, stmt,rrs, mysqlSelectQuery); - //更改canRunInReadDB属性 - if ((mysqlSelectQuery.isForUpdate() || mysqlSelectQuery.isLockInShareMode()) && rrs.isAutocommit() == false) - { - rrs.setCanRunInReadDB(false); - } - - } else if (sqlSelectQuery instanceof MySqlUnionQuery) { -// MySqlUnionQuery unionQuery = (MySqlUnionQuery)sqlSelectQuery; -// MySqlSelectQueryBlock left = (MySqlSelectQueryBlock)unionQuery.getLeft(); -// MySqlSelectQueryBlock right = (MySqlSelectQueryBlock)unionQuery.getLeft(); -// System.out.println(); - } - } - protected void parseOrderAggGroupMysql(SchemaConfig schema, SQLStatement stmt, RouteResultset rrs, MySqlSelectQueryBlock mysqlSelectQuery) - { - MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor(); - stmt.accept(visitor); -// rrs.setGroupByCols((String[])visitor.getGroupByColumns().toArray()); - if(!isNeedParseOrderAgg) - { - return; - } - Map aliaColumns = parseAggGroupCommon(schema, stmt, rrs, mysqlSelectQuery); - - //setOrderByCols - if(mysqlSelectQuery.getOrderBy() != null) { - List orderByItems = mysqlSelectQuery.getOrderBy().getItems(); - rrs.setOrderByCols(buildOrderByCols(orderByItems,aliaColumns)); - } - isNeedParseOrderAgg=false; - } - protected Map parseAggGroupCommon(SchemaConfig schema, SQLStatement stmt, RouteResultset rrs, SQLSelectQueryBlock mysqlSelectQuery) - { - Map aliaColumns = new HashMap(); - Map aggrColumns = new HashMap(); - // Added by winbill, 20160314, for having clause, Begin ==> - List havingColsName = new ArrayList(); - // Added by winbill, 20160314, for having clause, End <== - List selectList = mysqlSelectQuery.getSelectList(); - boolean isNeedChangeSql=false; - int size = selectList.size(); - boolean isDistinct=mysqlSelectQuery.getDistionOption()==2; - for (int i = 0; i < size; i++) - { - SQLSelectItem item = selectList.get(i); - - if (item.getExpr() instanceof SQLAggregateExpr) - { - SQLAggregateExpr expr = (SQLAggregateExpr) item.getExpr(); - String method = expr.getMethodName(); - boolean isHasArgument=!expr.getArguments().isEmpty(); - if(isHasArgument) - { - String aggrColName = method + "(" + expr.getArguments().get(0) + ")"; // Added by winbill, 20160314, for having clause - havingColsName.add(aggrColName); // Added by winbill, 20160314, for having clause - } - //只处理有别名的情况,无别名添加别名,否则某些数据库会得不到正确结果处理 - int mergeType = MergeCol.getMergeType(method); - if (MergeCol.MERGE_AVG == mergeType&&isRoutMultiNode(schema,rrs)) - { //跨分片avg需要特殊处理,直接avg结果是不对的 - String colName = item.getAlias() != null ? item.getAlias() : method + i; - SQLSelectItem sum =new SQLSelectItem(); - String sumColName = colName + "SUM"; - sum.setAlias(sumColName); - SQLAggregateExpr sumExp =new SQLAggregateExpr("SUM"); - ObjectUtil.copyProperties(expr,sumExp); - sumExp.getArguments().addAll(expr.getArguments()); - sumExp.setMethodName("SUM"); - sum.setExpr(sumExp); - selectList.set(i, sum); - aggrColumns.put(sumColName, MergeCol.MERGE_SUM); - havingColsName.add(sumColName); // Added by winbill, 20160314, for having clause - havingColsName.add(item.getAlias() != null ? item.getAlias() : ""); // Added by winbill, 20160314, two aliases for AVG - - SQLSelectItem count =new SQLSelectItem(); - String countColName = colName + "COUNT"; - count.setAlias(countColName); - SQLAggregateExpr countExp = new SQLAggregateExpr("COUNT"); - ObjectUtil.copyProperties(expr,countExp); - countExp.getArguments().addAll(expr.getArguments()); - countExp.setMethodName("COUNT"); - count.setExpr(countExp); - selectList.add(count); - aggrColumns.put(countColName, MergeCol.MERGE_COUNT); - - isNeedChangeSql=true; - aggrColumns.put(colName, mergeType); - rrs.setHasAggrColumn(true); - } else if (MergeCol.MERGE_UNSUPPORT != mergeType){ - String aggColName = null; - StringBuilder sb = new StringBuilder(); - if(mysqlSelectQuery instanceof MySqlSelectQueryBlock) { - expr.accept(new MySqlOutputVisitor(sb)); - } else if(mysqlSelectQuery instanceof OracleSelectQueryBlock) { - expr.accept(new OracleOutputVisitor(sb)); - } else if(mysqlSelectQuery instanceof PGSelectQueryBlock){ - expr.accept(new PGOutputVisitor(sb)); - } else if(mysqlSelectQuery instanceof SQLServerSelectQueryBlock) { - expr.accept(new SQLASTOutputVisitor(sb)); - } else if(mysqlSelectQuery instanceof DB2SelectQueryBlock) { - expr.accept(new DB2OutputVisitor(sb)); - } - aggColName = sb.toString(); - - if (item.getAlias() != null && item.getAlias().length() > 0) - { - aggrColumns.put(item.getAlias(), mergeType); - aliaColumns.put(aggColName,item.getAlias()); - } else - { //如果不加,jdbc方式时取不到正确结果 ;修改添加别名 - item.setAlias(method + i); - aggrColumns.put(method + i, mergeType); - aliaColumns.put(aggColName, method + i); - isNeedChangeSql=true; - } - rrs.setHasAggrColumn(true); - havingColsName.add(item.getAlias()); // Added by winbill, 20160314, for having clause - havingColsName.add(""); // Added by winbill, 20160314, one alias for non-AVG - } - } else - { - if (!(item.getExpr() instanceof SQLAllColumnExpr)) - { - String alia = item.getAlias(); - String field = getFieldName(item); - if (alia == null) - { - alia = field; - } - aliaColumns.put(field, alia); - } - } - - } - if(aggrColumns.size() > 0) { - rrs.setMergeCols(aggrColumns); - } - - //通过优化转换成group by来实现 - if(isDistinct) - { - mysqlSelectQuery.setDistionOption(0); - SQLSelectGroupByClause groupBy=new SQLSelectGroupByClause(); - for (String fieldName : aliaColumns.keySet()) - { - groupBy.addItem(new SQLIdentifierExpr(fieldName)); - } - mysqlSelectQuery.setGroupBy(groupBy); - isNeedChangeSql=true; - } - - - //setGroupByCols - if(mysqlSelectQuery.getGroupBy() != null) { - List groupByItems = mysqlSelectQuery.getGroupBy().getItems(); - String[] groupByCols = buildGroupByCols(groupByItems,aliaColumns); - WildcardUtil.wildcards(groupByCols); - rrs.setGroupByCols(groupByCols); - rrs.setHavings(buildGroupByHaving(mysqlSelectQuery.getGroupBy().getHaving(),aliaColumns)); - rrs.setHasAggrColumn(true); - rrs.setHavingColsName(havingColsName.toArray()); // Added by winbill, 20160314, for having clause - } - - - if (isNeedChangeSql) - { - String sql = stmt.toString(); - rrs.changeNodeSqlAfterAddLimit(schema,getCurentDbType(),sql,0,-1, false); - getCtx().setSql(sql); - } - return aliaColumns; - } - - private HavingCols buildGroupByHaving(SQLExpr having,Map aliaColumns ){ - if (having == null) { - return null; - } - - SQLBinaryOpExpr expr = ((SQLBinaryOpExpr) having); - SQLExpr left = expr.getLeft(); - SQLBinaryOperator operator = expr.getOperator(); - SQLExpr right = expr.getRight(); - - String leftValue = null;; - if (left instanceof SQLAggregateExpr) { - leftValue = ((SQLAggregateExpr) left).getMethodName() + "(" - + ((SQLAggregateExpr) left).getArguments().get(0) + ")"; - String aggrColumnAlias = getAliaColumn(aliaColumns,leftValue); - if(aggrColumnAlias != null) { // having聚合函数存在别名 - expr.setLeft(new SQLIdentifierExpr(aggrColumnAlias)); - leftValue = aggrColumnAlias; - } - } else if (left instanceof SQLIdentifierExpr) { - leftValue = ((SQLIdentifierExpr) left).getName(); - } - - String rightValue = null; - if (right instanceof SQLNumericLiteralExpr) { - rightValue = right.toString(); - }else if(right instanceof SQLTextLiteralExpr){ - rightValue = StringUtil.removeBackquote(right.toString()); - } - - return new HavingCols(leftValue,rightValue,operator.getName()); - } - - private boolean isRoutMultiNode(SchemaConfig schema, RouteResultset rrs) - { - if(rrs.getNodes()!=null&&rrs.getNodes().length>1) - { - return true; - } - LayerCachePool tableId2DataNodeCache = (LayerCachePool) MycatServer.getInstance().getCacheService().getCachePool("TableID2DataNodeCache"); - try - { - tryRoute(schema, rrs, tableId2DataNodeCache); - if(rrs.getNodes()!=null&&rrs.getNodes().length>1) - { - return true; - } - } catch (SQLNonTransientException e) - { - throw new RuntimeException(e); - } - return false; - } - - private String getFieldName(SQLSelectItem item){ - if ((item.getExpr() instanceof SQLPropertyExpr)||(item.getExpr() instanceof SQLMethodInvokeExpr) - || (item.getExpr() instanceof SQLIdentifierExpr) || item.getExpr() instanceof SQLBinaryOpExpr) { - return item.getExpr().toString();//字段别名 - } - else { - return item.toString(); - } - } - - /** - * 现阶段目标为 有一个只涉及到一张表的子查询时,先执行子查询,获得返回结果后,改写原有sql继续执行,得到最终结果. - * 在这种情况下,原sql不需要继续解析. - * 使用catlet 的情况也不再继续解析. - */ - @Override - public boolean afterVisitorParser(RouteResultset rrs, SQLStatement stmt, MycatSchemaStatVisitor visitor) { - int subQuerySize = visitor.getSubQuerys().size(); - - if(subQuerySize==0&&ctx.getTables().size()==2){ //两表关联,考虑使用catlet - if(ctx.getVisitor().getConditions() !=null && ctx.getVisitor().getConditions().size()>0){ - return true; - } - }else if(subQuerySize==1){ //只涉及一张表的子查询,使用 MiddlerResultHandler 获取中间结果后,改写原有 sql 继续执行 TODO 后期可能会考虑多个. - SQLSelectQuery sqlSelectQuery = visitor.getSubQuerys().iterator().next().getQuery(); - if(((MySqlSelectQueryBlock)sqlSelectQuery).getFrom() instanceof SQLExprTableSource) { - return true; - } - } - - return super.afterVisitorParser(rrs, stmt, visitor); - } - - /** - * 改写sql:需要加limit的加上 - */ - @Override - public void changeSql(SchemaConfig schema, RouteResultset rrs, SQLStatement stmt,LayerCachePool cachePool) throws SQLNonTransientException { - - tryRoute(schema, rrs, cachePool); - - rrs.copyLimitToNodes(); - - SQLSelectStatement selectStmt = (SQLSelectStatement)stmt; - SQLSelectQuery sqlSelectQuery = selectStmt.getSelect().getQuery(); - if(sqlSelectQuery instanceof MySqlSelectQueryBlock) { - MySqlSelectQueryBlock mysqlSelectQuery = (MySqlSelectQueryBlock)selectStmt.getSelect().getQuery(); - int limitStart = 0; - int limitSize = schema.getDefaultMaxLimit(); - - //clear group having - SQLSelectGroupByClause groupByClause = mysqlSelectQuery.getGroupBy(); - // Modified by winbill, 20160614, do NOT include having clause when routing to multiple nodes - if(groupByClause != null && groupByClause.getHaving() != null && isRoutMultiNode(schema,rrs)){ - groupByClause.setHaving(null); - } - - Map>> allConditions = getAllConditions(); - boolean isNeedAddLimit = isNeedAddLimit(schema, rrs, mysqlSelectQuery, allConditions); - if(isNeedAddLimit) { - Limit limit = new Limit(); - limit.setRowCount(new SQLIntegerExpr(limitSize)); - mysqlSelectQuery.setLimit(limit); - rrs.setLimitSize(limitSize); - String sql= getSql(rrs, stmt, isNeedAddLimit); - rrs.changeNodeSqlAfterAddLimit(schema, getCurentDbType(), sql, 0, limitSize, true); - - } - Limit limit = mysqlSelectQuery.getLimit(); - if(limit != null&&!isNeedAddLimit) { - SQLIntegerExpr offset = (SQLIntegerExpr)limit.getOffset(); - SQLIntegerExpr count = (SQLIntegerExpr)limit.getRowCount(); - if(offset != null) { - limitStart = offset.getNumber().intValue(); - rrs.setLimitStart(limitStart); - } - if(count != null) { - limitSize = count.getNumber().intValue(); - rrs.setLimitSize(limitSize); - } - - if(isNeedChangeLimit(rrs)) { - Limit changedLimit = new Limit(); - changedLimit.setRowCount(new SQLIntegerExpr(limitStart + limitSize)); - - if(offset != null) { - if(limitStart < 0) { - String msg = "You have an error in your SQL syntax; check the manual that " + - "corresponds to your MySQL server version for the right syntax to use near '" + limitStart + "'"; - throw new SQLNonTransientException(ErrorCode.ER_PARSE_ERROR + " - " + msg); - } else { - changedLimit.setOffset(new SQLIntegerExpr(0)); - - } - } - - mysqlSelectQuery.setLimit(changedLimit); - - String sql= getSql(rrs, stmt, isNeedAddLimit); - rrs.changeNodeSqlAfterAddLimit(schema,getCurentDbType(),sql,0, limitStart + limitSize, true); - - //设置改写后的sql - ctx.setSql(sql); - - } else - { - - rrs.changeNodeSqlAfterAddLimit(schema,getCurentDbType(),getCtx().getSql(),rrs.getLimitStart(), rrs.getLimitSize(), true); - // ctx.setSql(nativeSql); - - } - - - } - - if(rrs.isDistTable()){ - - for (RouteResultsetNode node : rrs.getNodes()) { - String sql = selectStmt.toString(); - SQLStatementParser parser = null; - if (schema.isNeedSupportMultiDBType()) { - parser = new MycatStatementParser(sql); - } else { - parser = new MySqlStatementParser(sql); - } - - SQLSelectStatement newStmt = null; - try { - newStmt = (SQLSelectStatement) parser.parseStatement(); - } catch (Exception t) { - LOGGER.error("DruidMycatRouteStrategyError", t); - throw new SQLSyntaxErrorException(t); - } - - MySqlSelectQueryBlock query = (MySqlSelectQueryBlock) newStmt.getSelect().getQuery(); - SQLTableSource from2 = query.getFrom(); - if (from2 instanceof SQLSubqueryTableSource) { - SQLSubqueryTableSource from = (SQLSubqueryTableSource) from2; - MySqlSelectQueryBlock query2 = (MySqlSelectQueryBlock) from.getSelect().getQuery(); - repairExpr(query2.getFrom(), node); - node.setStatement(newStmt.toString()); - } else { - repairExpr(from2, node); - node.setStatement(newStmt.toString()); - } - } - - } - - rrs.setCacheAble(isNeedCache(schema, rrs, mysqlSelectQuery, allConditions)); - } - - } - /*private void sqlRoute(RouteResultset rrs, SQLSelectStatement selectStmt, MySqlSelectQueryBlock mysqlSelectQuery, - MySqlSelectQueryBlock query) throws SQLNonTransientException { - SQLTableSource from2 = query.getFrom(); - SQLExprTableSource left2 = (SQLExprTableSource) getExpr(from2); - String alias = left2.getAlias(); - - left2.setAlias(alias); - - SQLTableSource from1 = mysqlSelectQuery.getFrom(); - - for (RouteResultsetNode node : rrs.getNodes()) { - SQLSelectStatement newStmt = new SQLSelectStatement(); - try { - BeanUtils.copyProperties(selectStmt, newStmt); - } catch (Exception e) { - throw new SQLNonTransientException(ErrorCode.ER_PARSE_ERROR + " - " + "copy exception"); - } - - repairExpr(from2, node); - //mysqlSelectQuery.setFrom(left2); - node.setStatement(selectStmt.toString()); - } - }*/ - - private void repairExpr(SQLTableSource source, RouteResultsetNode node) throws SQLNonTransientException { - Map subTableNames = node.getSubTableNames(); - - if (source instanceof SQLJoinTableSource) { - SQLJoinTableSource joinsource = (SQLJoinTableSource) source; - SQLTableSource right = joinsource.getRight(); - String alias = right.getAlias(); - - if(right instanceof SQLExprTableSource) { - SQLIdentifierExpr expr = (SQLIdentifierExpr) ((SQLExprTableSource) right).getExpr(); - alias = Strings.isBlank(alias) ? expr.getName() : alias; - String subTableName = subTableNames.get(WildcardUtil.wildcard(expr.getName().toUpperCase())); - if (Strings.isNotBlank(subTableName)) { - String tableName = expr.getName(); - alias = Strings.isBlank(alias) ? tableName : alias; - right.setAlias(alias); - - expr.setName(subTableName); - } - } - - if (right instanceof SQLSubqueryTableSource) { - SQLSubqueryTableSource from = (SQLSubqueryTableSource) right; - MySqlSelectQueryBlock query2 = (MySqlSelectQueryBlock) from.getSelect().getQuery(); - repairExpr(query2.getFrom(), node); - } - - repairExpr(joinsource.getLeft(), node); - } else if (source instanceof SQLExprTableSource) { - SQLExprTableSource exprTableSource = (SQLExprTableSource) source; - String alias = exprTableSource.getAlias(); - - SQLIdentifierExpr expr = (SQLIdentifierExpr) exprTableSource.getExpr(); - alias = Strings.isBlank(alias) ? expr.getName() : alias; - String subTableName = subTableNames.get(WildcardUtil.wildcard(expr.getName().toUpperCase())); - if (Strings.isNotBlank(subTableName)) { - String tableName = expr.getName(); - alias = Strings.isBlank(alias) ? tableName : alias; - exprTableSource.setAlias(alias); - expr.setName(subTableName); - } - - } - } - - /** - * 获取所有的条件:因为可能被or语句拆分成多个RouteCalculateUnit,条件分散了 - * @return - */ - private Map>> getAllConditions() { - Map>> map = new HashMap>>(); - for(RouteCalculateUnit unit : ctx.getRouteCalculateUnits()) { - if(unit != null && unit.getTablesAndConditions() != null) { - map.putAll(unit.getTablesAndConditions()); - } - } - - return map; - } - - private void tryRoute(SchemaConfig schema, RouteResultset rrs, LayerCachePool cachePool) throws SQLNonTransientException { - if(rrs.isFinishedRoute()) - { - return;//避免重复路由 - } - - //无表的select语句直接路由带任一节点 - if((ctx.getTables() == null || ctx.getTables().size() == 0)&&(ctx.getTableAliasMap()==null||ctx.getTableAliasMap().isEmpty())) { - rrs = RouterUtil.routeToSingleNode(rrs, schema.getRandomDataNode(), ctx.getSql()); - rrs.setFinishedRoute(true); - return; - } -// RouterUtil.tryRouteForTables(schema, ctx, rrs, true, cachePool); - SortedSet nodeSet = new TreeSet(); - boolean isAllGlobalTable = RouterUtil.isAllGlobalTable(ctx, schema); - for (RouteCalculateUnit unit : ctx.getRouteCalculateUnits()) { - RouteResultset rrsTmp = RouterUtil.tryRouteForTables(schema, ctx, unit, rrs, true, cachePool); - if (rrsTmp != null&&rrsTmp.getNodes()!=null) { - for (RouteResultsetNode node : rrsTmp.getNodes()) { - nodeSet.add(node); - } - } - if(isAllGlobalTable) {//都是全局表时只计算一遍路由 - break; - } - } - - if(nodeSet.size() == 0) { - - Collection stringCollection= ctx.getTableAliasMap().values() ; - for (String table : stringCollection) - { - if(table!=null&&table.toLowerCase().contains("information_schema.")) - { - rrs = RouterUtil.routeToSingleNode(rrs, schema.getRandomDataNode(), ctx.getSql()); - rrs.setFinishedRoute(true); - return; - } - } - String msg = " find no Route:" + ctx.getSql(); - LOGGER.warn(msg); - throw new SQLNonTransientException(msg); - } - - RouteResultsetNode[] nodes = new RouteResultsetNode[nodeSet.size()]; - int i = 0; - for (Iterator iterator = nodeSet.iterator(); iterator.hasNext();) { - nodes[i] = (RouteResultsetNode) iterator.next(); - i++; - - } - - rrs.setNodes(nodes); - rrs.setFinishedRoute(true); - } - - - protected String getCurentDbType() - { - return JdbcConstants.MYSQL; - } - - - - - protected String getSql( RouteResultset rrs,SQLStatement stmt, boolean isNeedAddLimit) - { - if(getCurentDbType().equalsIgnoreCase("mysql")&&(isNeedChangeLimit(rrs)||isNeedAddLimit)) - { - - return stmt.toString(); - - } - - return getCtx().getSql(); - } - - - - protected boolean isNeedChangeLimit(RouteResultset rrs) { - if(rrs.getNodes() == null) { - return false; - } else { - if(rrs.getNodes().length > 1) { - return true; - } - return false; - - } - } - - private boolean isNeedCache(SchemaConfig schema, RouteResultset rrs, - MySqlSelectQueryBlock mysqlSelectQuery, Map>> allConditions) { - if(ctx.getTables() == null || ctx.getTables().size() == 0 ) { - return false; - } - TableConfig tc = schema.getTables().get(ctx.getTables().get(0)); - if(tc==null ||(ctx.getTables().size() == 1 && tc.isGlobalTable()) - ) {//|| (ctx.getTables().size() == 1) && tc.getRule() == null && tc.getDataNodes().size() == 1 - return false; - } else { - //单表主键查询 - if(ctx.getTables().size() == 1) { - String tableName = ctx.getTables().get(0); - String primaryKey = schema.getTables().get(tableName).getPrimaryKey(); -// schema.getTables().get(ctx.getTables().get(0)).getParentKey() != null; - if(ctx.getRouteCalculateUnit().getTablesAndConditions().get(tableName) != null - && ctx.getRouteCalculateUnit().getTablesAndConditions().get(tableName).get(primaryKey) != null - && tc.getDataNodes().size() > 1) {//有主键条件 - return false; - } - //全局表不缓存 - }else if(RouterUtil.isAllGlobalTable(ctx, schema)){ - return false; - } - return true; - } - } - - /** - * 单表且是全局表 - * 单表且rule为空且nodeNodes只有一个 - * @param schema - * @param rrs - * @param mysqlSelectQuery - * @return - */ - private boolean isNeedAddLimit(SchemaConfig schema, RouteResultset rrs, - MySqlSelectQueryBlock mysqlSelectQuery, Map>> allConditions) { -// ctx.getTablesAndConditions().get(key)) - if(rrs.getLimitSize()>-1) - { - return false; - }else - if(schema.getDefaultMaxLimit() == -1) { - return false; - } else if (mysqlSelectQuery.getLimit() != null) {//语句中已有limit - return false; - } else if(ctx.getTables().size() == 1) { - String tableName = ctx.getTables().get(0); - TableConfig tableConfig = schema.getTables().get(tableName); - if(tableConfig==null) - { - return schema.getDefaultMaxLimit() > -1; // 找不到则取schema的配置 - } - - boolean isNeedAddLimit= tableConfig.isNeedAddLimit(); - if(!isNeedAddLimit) - { - return false;//优先从配置文件取 - } - - if(schema.getTables().get(tableName).isGlobalTable()) { - return true; - } - - String primaryKey = schema.getTables().get(tableName).getPrimaryKey(); - -// schema.getTables().get(ctx.getTables().get(0)).getParentKey() != null; - if(allConditions.get(tableName) == null) {//无条件 - return true; - } - - if (allConditions.get(tableName).get(primaryKey) != null) {//条件中带主键 - return false; - } - - return true; - } else if(rrs.hasPrimaryKeyToCache() && ctx.getTables().size() == 1){//只有一个表且条件中有主键,不需要limit了,因为主键只能查到一条记录 - return false; - } else {//多表或无表 - return false; - } - - } - private String getAliaColumn(Map aliaColumns,String column ){ - String alia=aliaColumns.get(column); - if (alia==null){ - if(column.indexOf(".") < 0) { - String col = "." + column; - String col2 = ".`" + column+"`"; - //展开aliaColumns,将之类的键值对展开成 - for(Map.Entry entry : aliaColumns.entrySet()) { - if(entry.getKey().endsWith(col)||entry.getKey().endsWith(col2)) { - if(entry.getValue() != null && entry.getValue().indexOf(".") > 0) { - return column; - } - return entry.getValue(); - } - } - } - - return column; - } - else { - return alia; - } - } - - private String[] buildGroupByCols(List groupByItems,Map aliaColumns) { - String[] groupByCols = new String[groupByItems.size()]; - for(int i= 0; i < groupByItems.size(); i++) { - SQLExpr sqlExpr = groupByItems.get(i); - String column = null; - if(sqlExpr instanceof SQLIdentifierExpr ) - { - column=((SQLIdentifierExpr) sqlExpr).getName(); - } else if(sqlExpr instanceof SQLMethodInvokeExpr){ - column = ((SQLMethodInvokeExpr) sqlExpr).toString(); - } else if(sqlExpr instanceof MySqlOrderingExpr){ - //todo czn - SQLExpr expr = ((MySqlOrderingExpr) sqlExpr).getExpr(); - - if (expr instanceof SQLName) - { - column = StringUtil.removeBackquote(((SQLName) expr).getSimpleName());//不要转大写 2015-2-10 sohudo StringUtil.removeBackquote(expr.getSimpleName().toUpperCase()); - } else - { - column = StringUtil.removeBackquote(expr.toString()); - } - } else if(sqlExpr instanceof SQLPropertyExpr){ - /** - * 针对子查询别名,例如select id from (select h.id from hotnews h union select h.title from hotnews h ) as t1 group by t1.id; - */ - column = sqlExpr.toString(); - } - if(column == null){ - column = sqlExpr.toString(); - } - int dotIndex=column.indexOf(".") ; - int bracketIndex=column.indexOf("(") ; - //通过判断含有括号来决定是否为函数列 - if(dotIndex!=-1&&bracketIndex==-1) - { - //此步骤得到的column必须是不带.的,有别名的用别名,无别名的用字段名 - column=column.substring(dotIndex+1) ; - } - groupByCols[i] = getAliaColumn(aliaColumns,column);//column; - } - return groupByCols; - } - - protected LinkedHashMap buildOrderByCols(List orderByItems,Map aliaColumns) { - LinkedHashMap map = new LinkedHashMap(); - for(int i= 0; i < orderByItems.size(); i++) { - SQLOrderingSpecification type = orderByItems.get(i).getType(); - //orderColumn只记录字段名称,因为返回的结果集是不带表名的。 - SQLExpr expr = orderByItems.get(i).getExpr(); - String col; - if (expr instanceof SQLName) { - col = ((SQLName)expr).getSimpleName(); - } - else { - col =expr.toString(); - } - if(type == null) { - type = SQLOrderingSpecification.ASC; - } - col=getAliaColumn(aliaColumns,col);//此步骤得到的col必须是不带.的,有别名的用别名,无别名的用字段名 - map.put(col, type == SQLOrderingSpecification.ASC ? OrderCol.COL_ORDER_TYPE_ASC : OrderCol.COL_ORDER_TYPE_DESC); - } - return map; - } - - private boolean isConditionAlwaysTrue(SQLStatement statement) { - SQLSelectStatement selectStmt = (SQLSelectStatement)statement; - SQLSelectQuery sqlSelectQuery = selectStmt.getSelect().getQuery(); - if(sqlSelectQuery instanceof MySqlSelectQueryBlock) { - MySqlSelectQueryBlock mysqlSelectQuery = (MySqlSelectQueryBlock)selectStmt.getSelect().getQuery(); - SQLExpr expr = mysqlSelectQuery.getWhere(); - - Object o = WallVisitorUtils.getValue(expr); - if(Boolean.TRUE.equals(o)) { - return true; - } - return false; - } else {//union - return false; - } - - } - - protected void setLimitIFChange(SQLStatement stmt, RouteResultset rrs, SchemaConfig schema, SQLBinaryOpExpr one, int firstrownum, int lastrownum) - { - rrs.setLimitStart(firstrownum); - rrs.setLimitSize(lastrownum - firstrownum); - LayerCachePool tableId2DataNodeCache = (LayerCachePool) MycatServer.getInstance().getCacheService().getCachePool("TableID2DataNodeCache"); - try - { - tryRoute(schema, rrs, tableId2DataNodeCache); - } catch (SQLNonTransientException e) - { - throw new RuntimeException(e); - } - if (isNeedChangeLimit(rrs)) - { - one.setRight(new SQLIntegerExpr(0)); - String curentDbType ="db2".equalsIgnoreCase(this.getCurentDbType())?"oracle":getCurentDbType(); - String sql = SQLUtils.toSQLString(stmt, curentDbType);; - rrs.changeNodeSqlAfterAddLimit(schema,getCurentDbType(), sql,0,lastrownum, false); - //设置改写后的sql - getCtx().setSql(sql); - } - } -} +package io.mycat.route.parser.druid.impl; + +import java.sql.SQLNonTransientException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; + +import com.alibaba.druid.sql.SQLUtils; +import com.alibaba.druid.sql.ast.SQLExpr; +import com.alibaba.druid.sql.ast.SQLName; +import com.alibaba.druid.sql.ast.SQLOrderingSpecification; +import com.alibaba.druid.sql.ast.SQLStatement; +import com.alibaba.druid.sql.ast.expr.SQLAggregateExpr; +import com.alibaba.druid.sql.ast.expr.SQLAllColumnExpr; +import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr; +import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator; +import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr; +import com.alibaba.druid.sql.ast.expr.SQLIntegerExpr; +import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr; +import com.alibaba.druid.sql.ast.expr.SQLNumericLiteralExpr; +import com.alibaba.druid.sql.ast.expr.SQLPropertyExpr; +import com.alibaba.druid.sql.ast.expr.SQLTextLiteralExpr; +import com.alibaba.druid.sql.ast.statement.SQLExprTableSource; +import com.alibaba.druid.sql.ast.statement.SQLSelectGroupByClause; +import com.alibaba.druid.sql.ast.statement.SQLSelectItem; +import com.alibaba.druid.sql.ast.statement.SQLSelectOrderByItem; +import com.alibaba.druid.sql.ast.statement.SQLSelectQuery; +import com.alibaba.druid.sql.ast.statement.SQLSelectQueryBlock; +import com.alibaba.druid.sql.ast.statement.SQLSelectStatement; +import com.alibaba.druid.sql.ast.statement.SQLTableSource; +import com.alibaba.druid.sql.dialect.db2.ast.stmt.DB2SelectQueryBlock; +import com.alibaba.druid.sql.dialect.db2.visitor.DB2OutputVisitor; +import com.alibaba.druid.sql.dialect.mysql.ast.expr.MySqlOrderingExpr; +import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock; +import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock.Limit; +import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUnionQuery; +import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlOutputVisitor; +import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor; +import com.alibaba.druid.sql.dialect.oracle.ast.stmt.OracleSelectQueryBlock; +import com.alibaba.druid.sql.dialect.oracle.visitor.OracleOutputVisitor; +import com.alibaba.druid.sql.dialect.postgresql.ast.stmt.PGSelectQueryBlock; +import com.alibaba.druid.sql.dialect.postgresql.visitor.PGOutputVisitor; +import com.alibaba.druid.sql.dialect.sqlserver.ast.SQLServerSelectQueryBlock; +import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor; +import com.alibaba.druid.util.JdbcConstants; +import com.alibaba.druid.wall.spi.WallVisitorUtils; + +import io.mycat.MycatServer; +import io.mycat.cache.LayerCachePool; +import io.mycat.config.ErrorCode; +import io.mycat.config.model.SchemaConfig; +import io.mycat.config.model.TableConfig; +import io.mycat.route.RouteResultset; +import io.mycat.route.RouteResultsetNode; +import io.mycat.route.parser.druid.MycatSchemaStatVisitor; +import io.mycat.route.parser.druid.RouteCalculateUnit; +import io.mycat.route.util.RouterUtil; +import io.mycat.sqlengine.mpp.ColumnRoutePair; +import io.mycat.sqlengine.mpp.HavingCols; +import io.mycat.sqlengine.mpp.MergeCol; +import io.mycat.sqlengine.mpp.OrderCol; +import io.mycat.util.ObjectUtil; +import io.mycat.util.StringUtil; + +public class DruidSelectParser extends DefaultDruidParser { + + + protected boolean isNeedParseOrderAgg=true; + + @Override + public void statementParse(SchemaConfig schema, RouteResultset rrs, SQLStatement stmt) { + SQLSelectStatement selectStmt = (SQLSelectStatement)stmt; + SQLSelectQuery sqlSelectQuery = selectStmt.getSelect().getQuery(); + if(sqlSelectQuery instanceof MySqlSelectQueryBlock) { + MySqlSelectQueryBlock mysqlSelectQuery = (MySqlSelectQueryBlock)selectStmt.getSelect().getQuery(); + + parseOrderAggGroupMysql(schema, stmt,rrs, mysqlSelectQuery); + //更改canRunInReadDB属性 + if ((mysqlSelectQuery.isForUpdate() || mysqlSelectQuery.isLockInShareMode()) && rrs.isAutocommit() == false) + { + rrs.setCanRunInReadDB(false); + } + + } else if (sqlSelectQuery instanceof MySqlUnionQuery) { +// MySqlUnionQuery unionQuery = (MySqlUnionQuery)sqlSelectQuery; +// MySqlSelectQueryBlock left = (MySqlSelectQueryBlock)unionQuery.getLeft(); +// MySqlSelectQueryBlock right = (MySqlSelectQueryBlock)unionQuery.getLeft(); +// System.out.println(); + } + } + protected void parseOrderAggGroupMysql(SchemaConfig schema, SQLStatement stmt, RouteResultset rrs, MySqlSelectQueryBlock mysqlSelectQuery) + { + MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor(); + stmt.accept(visitor); +// rrs.setGroupByCols((String[])visitor.getGroupByColumns().toArray()); + if(!isNeedParseOrderAgg) + { + return; + } + Map aliaColumns = parseAggGroupCommon(schema, stmt, rrs, mysqlSelectQuery); + + //setOrderByCols + if(mysqlSelectQuery.getOrderBy() != null) { + List orderByItems = mysqlSelectQuery.getOrderBy().getItems(); + rrs.setOrderByCols(buildOrderByCols(orderByItems,aliaColumns)); + } + isNeedParseOrderAgg=false; + } + protected Map parseAggGroupCommon(SchemaConfig schema, SQLStatement stmt, RouteResultset rrs, SQLSelectQueryBlock mysqlSelectQuery) + { + Map aliaColumns = new HashMap(); + Map aggrColumns = new HashMap(); + // Added by winbill, 20160314, for having clause, Begin ==> + List havingColsName = new ArrayList(); + // Added by winbill, 20160314, for having clause, End <== + List selectList = mysqlSelectQuery.getSelectList(); + boolean isNeedChangeSql=false; + int size = selectList.size(); + boolean isDistinct=mysqlSelectQuery.getDistionOption()==2; + for (int i = 0; i < size; i++) + { + SQLSelectItem item = selectList.get(i); + + if (item.getExpr() instanceof SQLAggregateExpr) + { + SQLAggregateExpr expr = (SQLAggregateExpr) item.getExpr(); + String method = expr.getMethodName(); + boolean isHasArgument=!expr.getArguments().isEmpty(); + if(isHasArgument) + { + String aggrColName = method + "(" + expr.getArguments().get(0) + ")"; // Added by winbill, 20160314, for having clause + havingColsName.add(aggrColName); // Added by winbill, 20160314, for having clause + } + //只处理有别名的情况,无别名添加别名,否则某些数据库会得不到正确结果处理 + int mergeType = MergeCol.getMergeType(method); + if (MergeCol.MERGE_AVG == mergeType&&isRoutMultiNode(schema,rrs)) + { //跨分片avg需要特殊处理,直接avg结果是不对的 + String colName = item.getAlias() != null ? item.getAlias() : method + i; + SQLSelectItem sum =new SQLSelectItem(); + String sumColName = colName + "SUM"; + sum.setAlias(sumColName); + SQLAggregateExpr sumExp =new SQLAggregateExpr("SUM"); + ObjectUtil.copyProperties(expr,sumExp); + sumExp.getArguments().addAll(expr.getArguments()); + sumExp.setMethodName("SUM"); + sum.setExpr(sumExp); + selectList.set(i, sum); + aggrColumns.put(sumColName, MergeCol.MERGE_SUM); + havingColsName.add(sumColName); // Added by winbill, 20160314, for having clause + havingColsName.add(item.getAlias() != null ? item.getAlias() : ""); // Added by winbill, 20160314, two aliases for AVG + + SQLSelectItem count =new SQLSelectItem(); + String countColName = colName + "COUNT"; + count.setAlias(countColName); + SQLAggregateExpr countExp = new SQLAggregateExpr("COUNT"); + ObjectUtil.copyProperties(expr,countExp); + countExp.getArguments().addAll(expr.getArguments()); + countExp.setMethodName("COUNT"); + count.setExpr(countExp); + selectList.add(count); + aggrColumns.put(countColName, MergeCol.MERGE_COUNT); + + isNeedChangeSql=true; + aggrColumns.put(colName, mergeType); + rrs.setHasAggrColumn(true); + } else if (MergeCol.MERGE_UNSUPPORT != mergeType){ + String aggColName = null; + StringBuilder sb = new StringBuilder(); + if(mysqlSelectQuery instanceof MySqlSelectQueryBlock) { + expr.accept(new MySqlOutputVisitor(sb)); + } else if(mysqlSelectQuery instanceof OracleSelectQueryBlock) { + expr.accept(new OracleOutputVisitor(sb)); + } else if(mysqlSelectQuery instanceof PGSelectQueryBlock){ + expr.accept(new PGOutputVisitor(sb)); + } else if(mysqlSelectQuery instanceof SQLServerSelectQueryBlock) { + expr.accept(new SQLASTOutputVisitor(sb)); + } else if(mysqlSelectQuery instanceof DB2SelectQueryBlock) { + expr.accept(new DB2OutputVisitor(sb)); + } + aggColName = sb.toString(); + + if (item.getAlias() != null && item.getAlias().length() > 0) + { + aggrColumns.put(item.getAlias(), mergeType); + aliaColumns.put(aggColName,item.getAlias()); + } else + { //如果不加,jdbc方式时取不到正确结果 ;修改添加别名 + item.setAlias(method + i); + aggrColumns.put(method + i, mergeType); + aliaColumns.put(aggColName, method + i); + isNeedChangeSql=true; + } + rrs.setHasAggrColumn(true); + havingColsName.add(item.getAlias()); // Added by winbill, 20160314, for having clause + havingColsName.add(""); // Added by winbill, 20160314, one alias for non-AVG + } + } else + { + if (!(item.getExpr() instanceof SQLAllColumnExpr)) + { + String alia = item.getAlias(); + String field = getFieldName(item); + if (alia == null) + { + alia = field; + } + aliaColumns.put(field, alia); + } + } + + } + if(aggrColumns.size() > 0) { + rrs.setMergeCols(aggrColumns); + } + + //通过优化转换成group by来实现 + if(isDistinct) + { + mysqlSelectQuery.setDistionOption(0); + SQLSelectGroupByClause groupBy=new SQLSelectGroupByClause(); + for (String fieldName : aliaColumns.keySet()) + { + groupBy.addItem(new SQLIdentifierExpr(fieldName)); + } + mysqlSelectQuery.setGroupBy(groupBy); + isNeedChangeSql=true; + } + + + //setGroupByCols + if(mysqlSelectQuery.getGroupBy() != null) { + List groupByItems = mysqlSelectQuery.getGroupBy().getItems(); + String[] groupByCols = buildGroupByCols(groupByItems,aliaColumns); + rrs.setGroupByCols(groupByCols); + rrs.setHavings(buildGroupByHaving(mysqlSelectQuery.getGroupBy().getHaving(),aliaColumns)); + rrs.setHasAggrColumn(true); + rrs.setHavingColsName(havingColsName.toArray()); // Added by winbill, 20160314, for having clause + } + + + if (isNeedChangeSql) + { + String sql = stmt.toString(); + rrs.changeNodeSqlAfterAddLimit(schema,getCurentDbType(),sql,0,-1, false); + getCtx().setSql(sql); + } + return aliaColumns; + } + + private HavingCols buildGroupByHaving(SQLExpr having,Map aliaColumns ){ + if (having == null) { + return null; + } + + SQLBinaryOpExpr expr = ((SQLBinaryOpExpr) having); + SQLExpr left = expr.getLeft(); + SQLBinaryOperator operator = expr.getOperator(); + SQLExpr right = expr.getRight(); + + String leftValue = null;; + if (left instanceof SQLAggregateExpr) { + leftValue = ((SQLAggregateExpr) left).getMethodName() + "(" + + ((SQLAggregateExpr) left).getArguments().get(0) + ")"; + String aggrColumnAlias = getAliaColumn(aliaColumns,leftValue); + if(aggrColumnAlias != null) { // having聚合函数存在别名 + expr.setLeft(new SQLIdentifierExpr(aggrColumnAlias)); + leftValue = aggrColumnAlias; + } + } else if (left instanceof SQLIdentifierExpr) { + leftValue = ((SQLIdentifierExpr) left).getName(); + } + + String rightValue = null; + if (right instanceof SQLNumericLiteralExpr) { + rightValue = right.toString(); + }else if(right instanceof SQLTextLiteralExpr){ + rightValue = StringUtil.removeBackquote(right.toString()); + } + + return new HavingCols(leftValue,rightValue,operator.getName()); + } + + private boolean isRoutMultiNode(SchemaConfig schema, RouteResultset rrs) + { + if(rrs.getNodes()!=null&&rrs.getNodes().length>1) + { + return true; + } + LayerCachePool tableId2DataNodeCache = (LayerCachePool) MycatServer.getInstance().getCacheService().getCachePool("TableID2DataNodeCache"); + try + { + tryRoute(schema, rrs, tableId2DataNodeCache); + if(rrs.getNodes()!=null&&rrs.getNodes().length>1) + { + return true; + } + } catch (SQLNonTransientException e) + { + throw new RuntimeException(e); + } + return false; + } + + private String getFieldName(SQLSelectItem item){ + if ((item.getExpr() instanceof SQLPropertyExpr)||(item.getExpr() instanceof SQLMethodInvokeExpr) + || (item.getExpr() instanceof SQLIdentifierExpr) || item.getExpr() instanceof SQLBinaryOpExpr) { + return item.getExpr().toString();//字段别名 + } + else { + return item.toString(); + } + } + + /** + * 现阶段目标为 有一个只涉及到一张表的子查询时,先执行子查询,获得返回结果后,改写原有sql继续执行,得到最终结果. + * 在这种情况下,原sql不需要继续解析. + * 使用catlet 的情况也不再继续解析. + */ + @Override + public boolean afterVisitorParser(RouteResultset rrs, SQLStatement stmt, MycatSchemaStatVisitor visitor) { + int subQuerySize = visitor.getSubQuerys().size(); + + if(subQuerySize==0&&ctx.getTables().size()==2){ //两表关联,考虑使用catlet + if(ctx.getVisitor().getConditions() !=null && ctx.getVisitor().getConditions().size()>0){ + return true; + } + }else if(subQuerySize==1){ //只涉及一张表的子查询,使用 MiddlerResultHandler 获取中间结果后,改写原有 sql 继续执行 TODO 后期可能会考虑多个. + SQLSelectQuery sqlSelectQuery = visitor.getSubQuerys().iterator().next().getQuery(); + if(((MySqlSelectQueryBlock)sqlSelectQuery).getFrom() instanceof SQLExprTableSource) { + return true; + } + } + + return super.afterVisitorParser(rrs, stmt, visitor); + } + + /** + * 改写sql:需要加limit的加上 + */ + @Override + public void changeSql(SchemaConfig schema, RouteResultset rrs, SQLStatement stmt,LayerCachePool cachePool) throws SQLNonTransientException { + + tryRoute(schema, rrs, cachePool); + + rrs.copyLimitToNodes(); + + SQLSelectStatement selectStmt = (SQLSelectStatement)stmt; + SQLSelectQuery sqlSelectQuery = selectStmt.getSelect().getQuery(); + if(sqlSelectQuery instanceof MySqlSelectQueryBlock) { + MySqlSelectQueryBlock mysqlSelectQuery = (MySqlSelectQueryBlock)selectStmt.getSelect().getQuery(); + int limitStart = 0; + int limitSize = schema.getDefaultMaxLimit(); + + //clear group having + SQLSelectGroupByClause groupByClause = mysqlSelectQuery.getGroupBy(); + // Modified by winbill, 20160614, do NOT include having clause when routing to multiple nodes + if(groupByClause != null && groupByClause.getHaving() != null && isRoutMultiNode(schema,rrs)){ + groupByClause.setHaving(null); + } + + Map>> allConditions = getAllConditions(); + boolean isNeedAddLimit = isNeedAddLimit(schema, rrs, mysqlSelectQuery, allConditions); + if(isNeedAddLimit) { + Limit limit = new Limit(); + limit.setRowCount(new SQLIntegerExpr(limitSize)); + mysqlSelectQuery.setLimit(limit); + rrs.setLimitSize(limitSize); + String sql= getSql(rrs, stmt, isNeedAddLimit); + rrs.changeNodeSqlAfterAddLimit(schema, getCurentDbType(), sql, 0, limitSize, true); + + } + Limit limit = mysqlSelectQuery.getLimit(); + if(limit != null&&!isNeedAddLimit) { + SQLIntegerExpr offset = (SQLIntegerExpr)limit.getOffset(); + SQLIntegerExpr count = (SQLIntegerExpr)limit.getRowCount(); + if(offset != null) { + limitStart = offset.getNumber().intValue(); + rrs.setLimitStart(limitStart); + } + if(count != null) { + limitSize = count.getNumber().intValue(); + rrs.setLimitSize(limitSize); + } + + if(isNeedChangeLimit(rrs)) { + Limit changedLimit = new Limit(); + changedLimit.setRowCount(new SQLIntegerExpr(limitStart + limitSize)); + + if(offset != null) { + if(limitStart < 0) { + String msg = "You have an error in your SQL syntax; check the manual that " + + "corresponds to your MySQL server version for the right syntax to use near '" + limitStart + "'"; + throw new SQLNonTransientException(ErrorCode.ER_PARSE_ERROR + " - " + msg); + } else { + changedLimit.setOffset(new SQLIntegerExpr(0)); + + } + } + + mysqlSelectQuery.setLimit(changedLimit); + + String sql= getSql(rrs, stmt, isNeedAddLimit); + rrs.changeNodeSqlAfterAddLimit(schema,getCurentDbType(),sql,0, limitStart + limitSize, true); + + //设置改写后的sql + ctx.setSql(sql); + + } else + { + + rrs.changeNodeSqlAfterAddLimit(schema,getCurentDbType(),getCtx().getSql(),rrs.getLimitStart(), rrs.getLimitSize(), true); + // ctx.setSql(nativeSql); + + } + + + } + + if(rrs.isDistTable()){ + SQLTableSource from = mysqlSelectQuery.getFrom(); + + for (RouteResultsetNode node : rrs.getNodes()) { + SQLIdentifierExpr sqlIdentifierExpr = new SQLIdentifierExpr(); + sqlIdentifierExpr.setParent(from); + sqlIdentifierExpr.setName(node.getSubTableName()); + SQLExprTableSource from2 = new SQLExprTableSource(sqlIdentifierExpr); + from2.setAlias(from.getAlias()); + mysqlSelectQuery.setFrom(from2); + node.setStatement(stmt.toString()); + } + } + + rrs.setCacheAble(isNeedCache(schema, rrs, mysqlSelectQuery, allConditions)); + } + + } + + /** + * 获取所有的条件:因为可能被or语句拆分成多个RouteCalculateUnit,条件分散了 + * @return + */ + private Map>> getAllConditions() { + Map>> map = new HashMap>>(); + for(RouteCalculateUnit unit : ctx.getRouteCalculateUnits()) { + if(unit != null && unit.getTablesAndConditions() != null) { + map.putAll(unit.getTablesAndConditions()); + } + } + + return map; + } + + private void tryRoute(SchemaConfig schema, RouteResultset rrs, LayerCachePool cachePool) throws SQLNonTransientException { + if(rrs.isFinishedRoute()) + { + return;//避免重复路由 + } + + //无表的select语句直接路由带任一节点 + if((ctx.getTables() == null || ctx.getTables().size() == 0)&&(ctx.getTableAliasMap()==null||ctx.getTableAliasMap().isEmpty())) { + rrs = RouterUtil.routeToSingleNode(rrs, schema.getRandomDataNode(), ctx.getSql()); + rrs.setFinishedRoute(true); + return; + } +// RouterUtil.tryRouteForTables(schema, ctx, rrs, true, cachePool); + SortedSet nodeSet = new TreeSet(); + boolean isAllGlobalTable = RouterUtil.isAllGlobalTable(ctx, schema); + for (RouteCalculateUnit unit : ctx.getRouteCalculateUnits()) { + RouteResultset rrsTmp = RouterUtil.tryRouteForTables(schema, ctx, unit, rrs, true, cachePool); + if (rrsTmp != null&&rrsTmp.getNodes()!=null) { + for (RouteResultsetNode node : rrsTmp.getNodes()) { + nodeSet.add(node); + } + } + if(isAllGlobalTable) {//都是全局表时只计算一遍路由 + break; + } + } + + if(nodeSet.size() == 0) { + + Collection stringCollection= ctx.getTableAliasMap().values() ; + for (String table : stringCollection) + { + if(table!=null&&table.toLowerCase().contains("information_schema.")) + { + rrs = RouterUtil.routeToSingleNode(rrs, schema.getRandomDataNode(), ctx.getSql()); + rrs.setFinishedRoute(true); + return; + } + } + String msg = " find no Route:" + ctx.getSql(); + LOGGER.warn(msg); + throw new SQLNonTransientException(msg); + } + + RouteResultsetNode[] nodes = new RouteResultsetNode[nodeSet.size()]; + int i = 0; + for (Iterator iterator = nodeSet.iterator(); iterator.hasNext();) { + nodes[i] = (RouteResultsetNode) iterator.next(); + i++; + + } + + rrs.setNodes(nodes); + rrs.setFinishedRoute(true); + } + + + protected String getCurentDbType() + { + return JdbcConstants.MYSQL; + } + + + + + protected String getSql( RouteResultset rrs,SQLStatement stmt, boolean isNeedAddLimit) + { + if(getCurentDbType().equalsIgnoreCase("mysql")&&(isNeedChangeLimit(rrs)||isNeedAddLimit)) + { + + return stmt.toString(); + + } + + return getCtx().getSql(); + } + + + + protected boolean isNeedChangeLimit(RouteResultset rrs) { + if(rrs.getNodes() == null) { + return false; + } else { + if(rrs.getNodes().length > 1) { + return true; + } + return false; + + } + } + + private boolean isNeedCache(SchemaConfig schema, RouteResultset rrs, + MySqlSelectQueryBlock mysqlSelectQuery, Map>> allConditions) { + if(ctx.getTables() == null || ctx.getTables().size() == 0 ) { + return false; + } + TableConfig tc = schema.getTables().get(ctx.getTables().get(0)); + if(tc==null ||(ctx.getTables().size() == 1 && tc.isGlobalTable()) + ) {//|| (ctx.getTables().size() == 1) && tc.getRule() == null && tc.getDataNodes().size() == 1 + return false; + } else { + //单表主键查询 + if(ctx.getTables().size() == 1) { + String tableName = ctx.getTables().get(0); + String primaryKey = schema.getTables().get(tableName).getPrimaryKey(); +// schema.getTables().get(ctx.getTables().get(0)).getParentKey() != null; + if(ctx.getRouteCalculateUnit().getTablesAndConditions().get(tableName) != null + && ctx.getRouteCalculateUnit().getTablesAndConditions().get(tableName).get(primaryKey) != null + && tc.getDataNodes().size() > 1) {//有主键条件 + return false; + } + //全局表不缓存 + }else if(RouterUtil.isAllGlobalTable(ctx, schema)){ + return false; + } + return true; + } + } + + /** + * 单表且是全局表 + * 单表且rule为空且nodeNodes只有一个 + * @param schema + * @param rrs + * @param mysqlSelectQuery + * @return + */ + private boolean isNeedAddLimit(SchemaConfig schema, RouteResultset rrs, + MySqlSelectQueryBlock mysqlSelectQuery, Map>> allConditions) { +// ctx.getTablesAndConditions().get(key)) + if(rrs.getLimitSize()>-1) + { + return false; + }else + if(schema.getDefaultMaxLimit() == -1) { + return false; + } else if (mysqlSelectQuery.getLimit() != null) {//语句中已有limit + return false; + } else if(ctx.getTables().size() == 1) { + String tableName = ctx.getTables().get(0); + TableConfig tableConfig = schema.getTables().get(tableName); + if(tableConfig==null) + { + return schema.getDefaultMaxLimit() > -1; // 找不到则取schema的配置 + } + + boolean isNeedAddLimit= tableConfig.isNeedAddLimit(); + if(!isNeedAddLimit) + { + return false;//优先从配置文件取 + } + + if(schema.getTables().get(tableName).isGlobalTable()) { + return true; + } + + String primaryKey = schema.getTables().get(tableName).getPrimaryKey(); + +// schema.getTables().get(ctx.getTables().get(0)).getParentKey() != null; + if(allConditions.get(tableName) == null) {//无条件 + return true; + } + + if (allConditions.get(tableName).get(primaryKey) != null) {//条件中带主键 + return false; + } + + return true; + } else if(rrs.hasPrimaryKeyToCache() && ctx.getTables().size() == 1){//只有一个表且条件中有主键,不需要limit了,因为主键只能查到一条记录 + return false; + } else {//多表或无表 + return false; + } + + } + private String getAliaColumn(Map aliaColumns,String column ){ + String alia=aliaColumns.get(column); + if (alia==null){ + if(column.indexOf(".") < 0) { + String col = "." + column; + String col2 = ".`" + column+"`"; + //展开aliaColumns,将之类的键值对展开成 + for(Map.Entry entry : aliaColumns.entrySet()) { + if(entry.getKey().endsWith(col)||entry.getKey().endsWith(col2)) { + if(entry.getValue() != null && entry.getValue().indexOf(".") > 0) { + return column; + } + return entry.getValue(); + } + } + } + + return column; + } + else { + return alia; + } + } + + private String[] buildGroupByCols(List groupByItems,Map aliaColumns) { + String[] groupByCols = new String[groupByItems.size()]; + for(int i= 0; i < groupByItems.size(); i++) { + SQLExpr sqlExpr = groupByItems.get(i); + String column = null; + if(sqlExpr instanceof SQLIdentifierExpr ) + { + column=((SQLIdentifierExpr) sqlExpr).getName(); + } else if(sqlExpr instanceof SQLMethodInvokeExpr){ + column = ((SQLMethodInvokeExpr) sqlExpr).toString(); + } else if(sqlExpr instanceof MySqlOrderingExpr){ + //todo czn + SQLExpr expr = ((MySqlOrderingExpr) sqlExpr).getExpr(); + + if (expr instanceof SQLName) + { + column = StringUtil.removeBackquote(((SQLName) expr).getSimpleName());//不要转大写 2015-2-10 sohudo StringUtil.removeBackquote(expr.getSimpleName().toUpperCase()); + } else + { + column = StringUtil.removeBackquote(expr.toString()); + } + } else if(sqlExpr instanceof SQLPropertyExpr){ + /** + * 针对子查询别名,例如select id from (select h.id from hotnews h union select h.title from hotnews h ) as t1 group by t1.id; + */ + column = sqlExpr.toString(); + } + if(column == null){ + column = sqlExpr.toString(); + } + int dotIndex=column.indexOf(".") ; + int bracketIndex=column.indexOf("(") ; + //通过判断含有括号来决定是否为函数列 + if(dotIndex!=-1&&bracketIndex==-1) + { + //此步骤得到的column必须是不带.的,有别名的用别名,无别名的用字段名 + column=column.substring(dotIndex+1) ; + } + groupByCols[i] = getAliaColumn(aliaColumns,column);//column; + } + return groupByCols; + } + + protected LinkedHashMap buildOrderByCols(List orderByItems,Map aliaColumns) { + LinkedHashMap map = new LinkedHashMap(); + for(int i= 0; i < orderByItems.size(); i++) { + SQLOrderingSpecification type = orderByItems.get(i).getType(); + //orderColumn只记录字段名称,因为返回的结果集是不带表名的。 + SQLExpr expr = orderByItems.get(i).getExpr(); + String col; + if (expr instanceof SQLName) { + col = ((SQLName)expr).getSimpleName(); + } + else { + col =expr.toString(); + } + if(type == null) { + type = SQLOrderingSpecification.ASC; + } + col=getAliaColumn(aliaColumns,col);//此步骤得到的col必须是不带.的,有别名的用别名,无别名的用字段名 + map.put(col, type == SQLOrderingSpecification.ASC ? OrderCol.COL_ORDER_TYPE_ASC : OrderCol.COL_ORDER_TYPE_DESC); + } + return map; + } + + private boolean isConditionAlwaysTrue(SQLStatement statement) { + SQLSelectStatement selectStmt = (SQLSelectStatement)statement; + SQLSelectQuery sqlSelectQuery = selectStmt.getSelect().getQuery(); + if(sqlSelectQuery instanceof MySqlSelectQueryBlock) { + MySqlSelectQueryBlock mysqlSelectQuery = (MySqlSelectQueryBlock)selectStmt.getSelect().getQuery(); + SQLExpr expr = mysqlSelectQuery.getWhere(); + + Object o = WallVisitorUtils.getValue(expr); + if(Boolean.TRUE.equals(o)) { + return true; + } + return false; + } else {//union + return false; + } + + } + + protected void setLimitIFChange(SQLStatement stmt, RouteResultset rrs, SchemaConfig schema, SQLBinaryOpExpr one, int firstrownum, int lastrownum) + { + rrs.setLimitStart(firstrownum); + rrs.setLimitSize(lastrownum - firstrownum); + LayerCachePool tableId2DataNodeCache = (LayerCachePool) MycatServer.getInstance().getCacheService().getCachePool("TableID2DataNodeCache"); + try + { + tryRoute(schema, rrs, tableId2DataNodeCache); + } catch (SQLNonTransientException e) + { + throw new RuntimeException(e); + } + if (isNeedChangeLimit(rrs)) + { + one.setRight(new SQLIntegerExpr(0)); + String curentDbType ="db2".equalsIgnoreCase(this.getCurentDbType())?"oracle":getCurentDbType(); + String sql = SQLUtils.toSQLString(stmt, curentDbType);; + rrs.changeNodeSqlAfterAddLimit(schema,getCurentDbType(), sql,0,lastrownum, false); + //设置改写后的sql + getCtx().setSql(sql); + } + } +} diff --git a/src/main/java/io/mycat/route/parser/util/WildcardUtil.java b/src/main/java/io/mycat/route/parser/util/WildcardUtil.java deleted file mode 100644 index 25a8f70b7..000000000 --- a/src/main/java/io/mycat/route/parser/util/WildcardUtil.java +++ /dev/null @@ -1,21 +0,0 @@ -package io.mycat.route.parser.util; - -public class WildcardUtil { - - public static String wildcard(String name) { - if (name.startsWith("`")) { - name = name.replaceAll("`", ""); - } else if (name.startsWith("\"")) { - name = name.replaceAll("\"", ""); - } else if (name.startsWith("'")) { - name = name.replaceAll("'", ""); - } - return name; - } - - public static void wildcards(String[] names) { - for (int i = 0; i < names.length; i++) { - names[i] = wildcard(names[i]); - } - } -} diff --git a/src/main/java/io/mycat/route/util/RouterUtil.java b/src/main/java/io/mycat/route/util/RouterUtil.java index c432334ca..ec300e8ab 100644 --- a/src/main/java/io/mycat/route/util/RouterUtil.java +++ b/src/main/java/io/mycat/route/util/RouterUtil.java @@ -1,1993 +1,1928 @@ -package io.mycat.route.util; - -import java.sql.SQLNonTransientException; -import java.sql.SQLSyntaxErrorException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.LinkedHashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.Callable; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.alibaba.druid.sql.ast.SQLExpr; -import com.alibaba.druid.sql.ast.SQLStatement; -import com.alibaba.druid.sql.ast.expr.SQLCharExpr; -import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr; -import com.alibaba.druid.sql.ast.statement.SQLCharacterDataType; -import com.alibaba.druid.sql.ast.statement.SQLColumnDefinition; -import com.alibaba.druid.sql.ast.statement.SQLCreateTableStatement; -import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlCreateTableStatement; -import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement; -import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser; -import com.alibaba.druid.wall.spi.WallVisitorUtils; -import com.google.common.base.Strings; -import com.google.common.collect.Maps; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; - -import io.mycat.MycatServer; -import io.mycat.backend.datasource.PhysicalDBNode; -import io.mycat.backend.datasource.PhysicalDBPool; -import io.mycat.backend.datasource.PhysicalDatasource; -import io.mycat.backend.mysql.nio.handler.FetchStoreNodeOfChildTableHandler; -import io.mycat.cache.LayerCachePool; -import io.mycat.config.ErrorCode; -import io.mycat.config.MycatConfig; -import io.mycat.config.model.SchemaConfig; -import io.mycat.config.model.TableConfig; -import io.mycat.config.model.rule.RuleConfig; -import io.mycat.route.RouteResultset; -import io.mycat.route.RouteResultsetNode; -import io.mycat.route.SessionSQLPair; -import io.mycat.route.function.AbstractPartitionAlgorithm; -import io.mycat.route.function.SlotFunction; -import io.mycat.route.parser.druid.DruidShardingParseInfo; -import io.mycat.route.parser.druid.RouteCalculateUnit; -import io.mycat.server.ServerConnection; -import io.mycat.server.parser.ServerParse; -import io.mycat.sqlengine.mpp.ColumnRoutePair; -import io.mycat.sqlengine.mpp.LoadData; -import io.mycat.util.StringUtil; - -/** - * 从ServerRouterUtil中抽取的一些公用方法,路由解析工具类 - * @author wang.dw - * - */ -public class RouterUtil { - - private static final Logger LOGGER = LoggerFactory.getLogger(RouterUtil.class); - - /** - * 移除执行语句中的数据库名 - * - * @param stmt 执行语句 - * @param schema 数据库名 - * @return 执行语句 - * @author mycat - * - * @modification 修正移除schema的方法 - * @date 2016/12/29 - * @modifiedBy Hash Zhang - * - */ - public static String removeSchema(String stmt, String schema) { - final String upStmt = stmt.toUpperCase(); - final String upSchema = schema.toUpperCase() + "."; - final String upSchema2 = new StringBuilder("`").append(schema.toUpperCase()).append("`.").toString(); - int strtPos = 0; - int indx = 0; - - int indx1 = upStmt.indexOf(upSchema, strtPos); - int indx2 = upStmt.indexOf(upSchema2, strtPos); - boolean flag = indx1 < indx2 ? indx1 == -1 : indx2 != -1; - indx = !flag ? indx1 > 0 ? indx1 : indx2 : indx2 > 0 ? indx2 : indx1; - if (indx < 0) { - return stmt; - } - - int firstE = upStmt.indexOf("'"); - int endE = upStmt.lastIndexOf("'"); - - StringBuilder sb = new StringBuilder(); - while (indx > 0) { - sb.append(stmt.substring(strtPos, indx)); - - if (flag) { - strtPos = indx + upSchema2.length(); - } else { - strtPos = indx + upSchema.length(); - } - if (indx > firstE && indx < endE && countChar(stmt, indx) % 2 == 1) { - sb.append(stmt.substring(indx, indx + schema.length() + 1)); - } - indx1 = upStmt.indexOf(upSchema, strtPos); - indx2 = upStmt.indexOf(upSchema2, strtPos); - flag = indx1 < indx2 ? indx1 == -1 : indx2 != -1; - indx = !flag ? indx1 > 0 ? indx1 : indx2 : indx2 > 0 ? indx2 : indx1; - } - sb.append(stmt.substring(strtPos)); - return sb.toString(); - } - - private static int countChar(String sql,int end) - { - int count=0; - boolean skipChar = false; - for (int i = 0; i < end; i++) { - if(sql.charAt(i)=='\'' && !skipChar) { - count++; - skipChar = false; - }else if( sql.charAt(i)=='\\'){ - skipChar = true; - }else{ - skipChar = false; - } - } - return count; - } - - /** - * 获取第一个节点作为路由 - * - * @param rrs 数据路由集合 - * @param dataNode 数据库所在节点 - * @param stmt 执行语句 - * @return 数据路由集合 - * - * @author mycat - */ - public static RouteResultset routeToSingleNode(RouteResultset rrs, - String dataNode, String stmt) { - if (dataNode == null) { - return rrs; - } - RouteResultsetNode[] nodes = new RouteResultsetNode[1]; - nodes[0] = new RouteResultsetNode(dataNode, rrs.getSqlType(), stmt);//rrs.getStatement() - nodes[0].setSource(rrs); - rrs.setNodes(nodes); - rrs.setFinishedRoute(true); - if(rrs.getDataNodeSlotMap().containsKey(dataNode)){ - nodes[0].setSlot(rrs.getDataNodeSlotMap().get(dataNode)); - } - if (rrs.getCanRunInReadDB() != null) { - nodes[0].setCanRunInReadDB(rrs.getCanRunInReadDB()); - } - if(rrs.getRunOnSlave() != null){ - nodes[0].setRunOnSlave(rrs.getRunOnSlave()); - } - - return rrs; - } - - - - /** - * 修复DDL路由 - * - * @return RouteResultset - * @author aStoneGod - */ - public static RouteResultset routeToDDLNode(RouteResultset rrs, int sqlType, String stmt,SchemaConfig schema) throws SQLSyntaxErrorException { - stmt = getFixedSql(stmt); - String tablename = ""; - final String upStmt = stmt.toUpperCase(); - if(upStmt.startsWith("CREATE")){ - if (upStmt.contains("CREATE INDEX ") || upStmt.contains("CREATE UNIQUE INDEX ")){ - tablename = RouterUtil.getTableName(stmt, RouterUtil.getCreateIndexPos(upStmt, 0)); - }else { - tablename = RouterUtil.getTableName(stmt, RouterUtil.getCreateTablePos(upStmt, 0)); - } - }else if(upStmt.startsWith("DROP")){ - if (upStmt.contains("DROP INDEX ")){ - tablename = RouterUtil.getTableName(stmt, RouterUtil.getDropIndexPos(upStmt, 0)); - }else { - tablename = RouterUtil.getTableName(stmt, RouterUtil.getDropTablePos(upStmt, 0)); - } - }else if(upStmt.startsWith("ALTER")){ - tablename = RouterUtil.getTableName(stmt, RouterUtil.getAlterTablePos(upStmt, 0)); - }else if (upStmt.startsWith("TRUNCATE")){ - tablename = RouterUtil.getTableName(stmt, RouterUtil.getTruncateTablePos(upStmt, 0)); - } - tablename = tablename.toUpperCase(); - - if (schema.getTables().containsKey(tablename)){ - if(ServerParse.DDL==sqlType){ - List dataNodes = new ArrayList<>(); - Map tables = schema.getTables(); - TableConfig tc=tables.get(tablename); - if (tables != null && (tc != null)) { - dataNodes = tc.getDataNodes(); - } - boolean isSlotFunction= tc.getRule() != null && tc.getRule().getRuleAlgorithm() instanceof SlotFunction; - Iterator iterator1 = dataNodes.iterator(); - int nodeSize = dataNodes.size(); - RouteResultsetNode[] nodes = new RouteResultsetNode[nodeSize]; - if(isSlotFunction){ - stmt=changeCreateTable(schema,tablename,stmt); - } - for(int i=0;i 0) { - tableName = tableName.substring(ind2 + 1); - } - return tableName; - } - - - /** - * 获取show语句table名字 - * - * @param stmt 执行语句 - * @param repPos 开始位置和位数 - * @return 表名 - * @author AStoneGod - */ - public static String getShowTableName(String stmt, int[] repPos) { - int startPos = repPos[0]; - int secInd = stmt.indexOf(' ', startPos + 1); - if (secInd < 0) { - secInd = stmt.length(); - } - - repPos[1] = secInd; - String tableName = stmt.substring(startPos, secInd).trim(); - - int ind2 = tableName.indexOf('.'); - if (ind2 > 0) { - tableName = tableName.substring(ind2 + 1); - } - return tableName; - } - - /** - * 获取语句中前关键字位置和占位个数表名位置 - * - * @param upStmt 执行语句 - * @param start 开始位置 - * @return int[] 关键字位置和占位个数 - * - * @author mycat - * - * @modification 修改支持语句中包含“IF NOT EXISTS”的情况 - * @date 2016/12/8 - * @modifiedBy Hash Zhang - */ - public static int[] getCreateTablePos(String upStmt, int start) { - String token1 = "CREATE "; - String token2 = " TABLE "; - String token3 = " EXISTS "; - int createInd = upStmt.indexOf(token1, start); - int tabInd1 = upStmt.indexOf(token2, start); - int tabInd2 = upStmt.indexOf(token3, tabInd1); - // 既包含CREATE又包含TABLE,且CREATE关键字在TABLE关键字之前 - if (createInd >= 0 && tabInd2 > 0 && tabInd2 > createInd) { - return new int[] { tabInd2, token3.length() }; - } else if(createInd >= 0 && tabInd1 > 0 && tabInd1 > createInd) { - return new int[] { tabInd1, token2.length() }; - } else { - return new int[] { -1, token2.length() };// 不满足条件时,只关注第一个返回值为-1,第二个任意 - } - } - - /** - * 获取语句中前关键字位置和占位个数表名位置 - * - * @param upStmt - * 执行语句 - * @param start - * 开始位置 - * @return int[]关键字位置和占位个数 - * @author aStoneGod - */ - public static int[] getCreateIndexPos(String upStmt, int start) { - String token1 = "CREATE "; - String token2 = " INDEX "; - String token3 = " ON "; - int createInd = upStmt.indexOf(token1, start); - int idxInd = upStmt.indexOf(token2, start); - int onInd = upStmt.indexOf(token3, start); - // 既包含CREATE又包含INDEX,且CREATE关键字在INDEX关键字之前, 且包含ON... - if (createInd >= 0 && idxInd > 0 && idxInd > createInd && onInd > 0 && onInd > idxInd) { - return new int[] {onInd , token3.length() }; - } else { - return new int[] { -1, token2.length() };// 不满足条件时,只关注第一个返回值为-1,第二个任意 - } - } - - /** - * 获取ALTER语句中前关键字位置和占位个数表名位置 - * - * @param upStmt 执行语句 - * @param start 开始位置 - * @return int[] 关键字位置和占位个数 - * @author aStoneGod - */ - public static int[] getAlterTablePos(String upStmt, int start) { - String token1 = "ALTER "; - String token2 = " TABLE "; - int createInd = upStmt.indexOf(token1, start); - int tabInd = upStmt.indexOf(token2, start); - // 既包含CREATE又包含TABLE,且CREATE关键字在TABLE关键字之前 - if (createInd >= 0 && tabInd > 0 && tabInd > createInd) { - return new int[] { tabInd, token2.length() }; - } else { - return new int[] { -1, token2.length() };// 不满足条件时,只关注第一个返回值为-1,第二个任意 - } - } - - /** - * 获取DROP语句中前关键字位置和占位个数表名位置 - * - * @param upStmt 执行语句 - * @param start 开始位置 - * @return int[] 关键字位置和占位个数 - * @author aStoneGod - */ - public static int[] getDropTablePos(String upStmt, int start) { - //增加 if exists判断 - if(upStmt.contains("EXISTS")){ - String token1 = "IF "; - String token2 = " EXISTS "; - int ifInd = upStmt.indexOf(token1, start); - int tabInd = upStmt.indexOf(token2, start); - if (ifInd >= 0 && tabInd > 0 && tabInd > ifInd) { - return new int[] { tabInd, token2.length() }; - } else { - return new int[] { -1, token2.length() };// 不满足条件时,只关注第一个返回值为-1,第二个任意 - } - }else { - String token1 = "DROP "; - String token2 = " TABLE "; - int createInd = upStmt.indexOf(token1, start); - int tabInd = upStmt.indexOf(token2, start); - - if (createInd >= 0 && tabInd > 0 && tabInd > createInd) { - return new int[] { tabInd, token2.length() }; - } else { - return new int[] { -1, token2.length() };// 不满足条件时,只关注第一个返回值为-1,第二个任意 - } - } - } - - - /** - * 获取DROP语句中前关键字位置和占位个数表名位置 - * - * @param upStmt - * 执行语句 - * @param start - * 开始位置 - * @return int[]关键字位置和占位个数 - * @author aStoneGod - */ - - public static int[] getDropIndexPos(String upStmt, int start) { - String token1 = "DROP "; - String token2 = " INDEX "; - String token3 = " ON "; - int createInd = upStmt.indexOf(token1, start); - int idxInd = upStmt.indexOf(token2, start); - int onInd = upStmt.indexOf(token3, start); - // 既包含CREATE又包含INDEX,且CREATE关键字在INDEX关键字之前, 且包含ON... - if (createInd >= 0 && idxInd > 0 && idxInd > createInd && onInd > 0 && onInd > idxInd) { - return new int[] {onInd , token3.length() }; - } else { - return new int[] { -1, token2.length() };// 不满足条件时,只关注第一个返回值为-1,第二个任意 - } - } - - /** - * 获取TRUNCATE语句中前关键字位置和占位个数表名位置 - * - * @param upStmt 执行语句 - * @param start 开始位置 - * @return int[] 关键字位置和占位个数 - * @author aStoneGod - */ - public static int[] getTruncateTablePos(String upStmt, int start) { - String token1 = "TRUNCATE "; - String token2 = " TABLE "; - int createInd = upStmt.indexOf(token1, start); - int tabInd = upStmt.indexOf(token2, start); - // 既包含CREATE又包含TABLE,且CREATE关键字在TABLE关键字之前 - if (createInd >= 0 && tabInd > 0 && tabInd > createInd) { - return new int[] { tabInd, token2.length() }; - } else { - return new int[] { -1, token2.length() };// 不满足条件时,只关注第一个返回值为-1,第二个任意 - } - } - - /** - * 获取语句中前关键字位置和占位个数表名位置 - * - * @param upStmt 执行语句 - * @param start 开始位置 - * @return int[] 关键字位置和占位个数 - * @author mycat - */ - public static int[] getSpecPos(String upStmt, int start) { - String token1 = " FROM "; - String token2 = " IN "; - int tabInd1 = upStmt.indexOf(token1, start); - int tabInd2 = upStmt.indexOf(token2, start); - if (tabInd1 > 0) { - if (tabInd2 < 0) { - return new int[] { tabInd1, token1.length() }; - } - return (tabInd1 < tabInd2) ? new int[] { tabInd1, token1.length() } - : new int[] { tabInd2, token2.length() }; - } else { - return new int[] { tabInd2, token2.length() }; - } - } - - /** - * 获取开始位置后的 LIKE、WHERE 位置 如果不含 LIKE、WHERE 则返回执行语句的长度 - * - * @param upStmt 执行sql - * @param start 开始位置 - * @return int - * @author mycat - */ - public static int getSpecEndPos(String upStmt, int start) { - int tabInd = upStmt.toUpperCase().indexOf(" LIKE ", start); - if (tabInd < 0) { - tabInd = upStmt.toUpperCase().indexOf(" WHERE ", start); - } - if (tabInd < 0) { - return upStmt.length(); - } - return tabInd; - } - - public static boolean processWithMycatSeq(SchemaConfig schema, int sqlType, - String origSQL, ServerConnection sc) { - // check if origSQL is with global sequence - // @micmiu it is just a simple judgement - //对应本地文件配置方式:insert into table1(id,name) values(next value for MYCATSEQ_GLOBAL,‘test’); - // edit by dingw,增加mycatseq_ 兼容,因为ServerConnection的373行,进行路由计算时,将原始语句全部转换为小写 - if (origSQL.indexOf(" MYCATSEQ_") != -1 || origSQL.indexOf("mycatseq_") != -1) { - processSQL(sc,schema,origSQL,sqlType); - return true; - } - return false; - } - - public static void processSQL(ServerConnection sc,SchemaConfig schema,String sql,int sqlType){ -// int sequenceHandlerType = MycatServer.getInstance().getConfig().getSystem().getSequnceHandlerType(); - final SessionSQLPair sessionSQLPair = new SessionSQLPair(sc.getSession2(), schema, sql, sqlType); -// modify by yanjunli 序列获取修改为多线程方式。使用分段锁方式,一个序列一把锁。 begin -// MycatServer.getInstance().getSequnceProcessor().addNewSql(sessionSQLPair); - MycatServer.getInstance().getSequenceExecutor().execute(new Runnable() { - @Override - public void run() { - MycatServer.getInstance().getSequnceProcessor().executeSeq(sessionSQLPair); - } - }); -// modify 序列获取修改为多线程方式。使用分段锁方式,一个序列一把锁。 end -// } - } - - public static boolean processInsert(SchemaConfig schema, int sqlType, - String origSQL, ServerConnection sc) throws SQLNonTransientException { - String tableName = StringUtil.getTableName(origSQL).toUpperCase(); - TableConfig tableConfig = schema.getTables().get(tableName); - boolean processedInsert=false; - //判断是有自增字段 - if (null != tableConfig && tableConfig.isAutoIncrement()) { - String primaryKey = tableConfig.getPrimaryKey(); - processedInsert=processInsert(sc,schema,sqlType,origSQL,tableName,primaryKey); - } - return processedInsert; - } - /* - * 找到返回主键的的位置 - * 找不到返回 -1 - * */ - private static int isPKInFields(String origSQL,String primaryKey,int firstLeftBracketIndex,int firstRightBracketIndex){ - - if (primaryKey == null) { - throw new RuntimeException("please make sure the primaryKey's config is not null in schemal.xml"); - } - - boolean isPrimaryKeyInFields = false; - int pkStart = 0; - String upperSQL = origSQL.substring(firstLeftBracketIndex, firstRightBracketIndex + 1).toUpperCase(); - for (int pkOffset = 0, primaryKeyLength = primaryKey.length();;) { - pkStart = upperSQL.indexOf(primaryKey, pkOffset); - if (pkStart >= 0 && pkStart < firstRightBracketIndex) { - char pkSide = upperSQL.charAt(pkStart - 1); - if (pkSide <= ' ' || pkSide == '`' || pkSide == ',' || pkSide == '(') { - pkSide = upperSQL.charAt(pkStart + primaryKey.length()); - isPrimaryKeyInFields = pkSide <= ' ' || pkSide == '`' || pkSide == ',' || pkSide == ')'; - } - if (isPrimaryKeyInFields) { - break; - } - pkOffset = pkStart + primaryKeyLength; - } else { - break; - } - } - if (isPrimaryKeyInFields) { - return firstLeftBracketIndex + pkStart; - } else { - return -1; - } - - } - - public static boolean processInsert(ServerConnection sc,SchemaConfig schema, - int sqlType,String origSQL,String tableName,String primaryKey) throws SQLNonTransientException { - - int firstLeftBracketIndex = origSQL.indexOf("("); - int firstRightBracketIndex = origSQL.indexOf(")"); - String upperSql = origSQL.toUpperCase(); - int valuesIndex = upperSql.indexOf("VALUES"); - int selectIndex = upperSql.indexOf("SELECT"); - int fromIndex = upperSql.indexOf("FROM"); - //屏蔽insert into table1 select * from table2语句 - if(firstLeftBracketIndex < 0) { - String msg = "invalid sql:" + origSQL; - LOGGER.warn(msg); - throw new SQLNonTransientException(msg); - } - //屏蔽批量插入 - if(selectIndex > 0 &&fromIndex>0&&selectIndex>firstRightBracketIndex&&valuesIndex<0) { - String msg = "multi insert not provided" ; - LOGGER.warn(msg); - throw new SQLNonTransientException(msg); - } - //插入语句必须提供列结构,因为MyCat默认对于表结构无感知 - if(valuesIndex + "VALUES".length() <= firstLeftBracketIndex) { - throw new SQLSyntaxErrorException("insert must provide ColumnList"); - } - List> vauleList = parseSqlValue(origSQL , valuesIndex); - //两种情况处理 1 有主键的 id ,但是值为null 进行改下 - // 2 没有主键的 需要插入 进行改写 - - //如果主键不在插入语句的fields中,则需要进一步处理 - boolean processedInsert= false; - int pkStart = isPKInFields(origSQL,primaryKey,firstLeftBracketIndex,firstRightBracketIndex); - - - if(pkStart == -1){ - processedInsert = true; - handleBatchInsert(sc, schema, sqlType,origSQL, valuesIndex, tableName, primaryKey, vauleList); - } else { - //判断 主键id的值是否为null - if(pkStart != -1) { - String subPrefix = origSQL.substring(0, pkStart); - char c; - int pkIndex = 0; - for(int index = 0, len = subPrefix.length(); index < len; index++) { - c = subPrefix.charAt(index); - if(c == ',') { - pkIndex ++; - } - } - processedInsert = handleBatchInsertWithPK(sc, schema, sqlType,origSQL, valuesIndex, tableName, primaryKey, vauleList , pkIndex); - } - } - return processedInsert; - } - - private static boolean handleBatchInsertWithPK(ServerConnection sc, SchemaConfig schema, int sqlType, - String origSQL, int valuesIndex, String tableName, String primaryKey, List> vauleList, - int pkIndex) { - boolean processedInsert = false; -// final String pk = "\\("+primaryKey+","; - final String mycatSeqPrefix = "next value for MYCATSEQ_"+tableName.toUpperCase() ; - - /*"VALUES".length() ==6 */ - String prefix = origSQL.substring(0, valuesIndex + 6); -// - - StringBuilder sb = new StringBuilder(""); - for(List list : vauleList) { - sb.append("("); - String pkValue = list.get(pkIndex).trim().toLowerCase(); - //null值替换为 next value for MYCATSEQ_tableName - if("null".equals(pkValue.trim())) { - list.set(pkIndex, mycatSeqPrefix); - processedInsert = true; - } - for(String val : list) { - sb.append(val).append(","); - } - sb.setCharAt(sb.length() - 1, ')'); - sb.append(","); - } - sb.setCharAt(sb.length() - 1, ' ');; - if(processedInsert) { - processSQL(sc, schema,prefix+sb.toString(), sqlType); - - } - return processedInsert; - } - - public static List handleBatchInsert(String origSQL, int valuesIndex) { - List handledSQLs = new LinkedList<>(); - String prefix = origSQL.substring(0, valuesIndex + "VALUES".length()); - String values = origSQL.substring(valuesIndex + "VALUES".length()); - int flag = 0; - StringBuilder currentValue = new StringBuilder(); - currentValue.append(prefix); - for (int i = 0; i < values.length(); i++) { - char j = values.charAt(i); - if (j == '(' && flag == 0) { - flag = 1; - currentValue.append(j); - } else if (j == '\"' && flag == 1) { - flag = 2; - currentValue.append(j); - } else if (j == '\'' && flag == 1) { - flag = 3; - currentValue.append(j); - } else if (j == '\\' && flag == 2) { - flag = 4; - currentValue.append(j); - } else if (j == '\\' && flag == 3) { - flag = 5; - currentValue.append(j); - } else if (flag == 4) { - flag = 2; - currentValue.append(j); - } else if (flag == 5) { - flag = 3; - currentValue.append(j); - } else if (j == '\"' && flag == 2) { - flag = 1; - currentValue.append(j); - } else if (j == '\'' && flag == 3) { - flag = 1; - currentValue.append(j); - } else if (j == ')' && flag == 1) { - flag = 0; - currentValue.append(j); - handledSQLs.add(currentValue.toString()); - currentValue = new StringBuilder(); - currentValue.append(prefix); - } else if (j == ',' && flag == 0) { - continue; - } else { - currentValue.append(j); - } - } - return handledSQLs; - } - /** - * 对于插入的sql : "insert into hotnews(title,name) values('test1',\"name\"),('(test)',\"(test)\"),('\\\"',\"\\'\"),(\")\",\"\\\"\\')\")": - * 需要返回结果: - *[[ 'test1', "name"], - * ['(test)', "(test)"], - * ['\"', "\'"], - * [")", "\"\')"], - * [ 1, null] - * 值结果的解析 - */ - public static List> parseSqlValue(String origSQL,int valuesIndex ) { - List> valueArray = new ArrayList<>(); - String valueStr = origSQL.substring(valuesIndex + 6);// 6 values 长度为6 - String preStr = origSQL.substring(0, valuesIndex );// 6 values 长度为6 - int pos = 0 ; - int flag = -1; - int len = valueStr.length(); - StringBuilder currentValue = new StringBuilder(); -// int colNum = 2; // - char c ; - List curList = new ArrayList<>(); - for( ;pos < len; pos ++) { - c = valueStr.charAt(pos); - if(flag == 1 || flag == 2) { - currentValue.append(c); - if(c == '\\') { - char nextCode = valueStr.charAt(pos + 1); - if(nextCode == '\'' || nextCode == '\"') { - currentValue.append(nextCode); - pos++; - continue; - } - } - if(c == '\"' && flag == 1) { - flag = 0; - continue; - } - if(c == '\'' && flag == 2) { - flag = 0; - continue; - } - } else if(c == '\"'){ - currentValue.append(c); - flag = 1; - } else if (c == '\'') { - currentValue.append(c); - flag = 2; - } else if (c == '(') { - curList = new ArrayList<>(); - flag = 0; - } else if(flag == 4 ) { - if(c == ',') { - flag = 0; - continue; - } - } else if(c == ',') { -// System.out.println(currentValue); - curList.add(currentValue.toString()); - currentValue.delete(0, currentValue.length()); - } else if(c == ')'){ - flag = 4; -// System.out.println(currentValue); - curList.add(currentValue.toString()); - currentValue.delete(0, currentValue.length()); - valueArray.add(curList); - } else { - currentValue.append(c); - } - } - return valueArray; - } - - /** - * 对于主键不在插入语句的fields中的SQL,需要改写。比如hotnews主键为id,插入语句为: - * insert into hotnews(title) values('aaa'); - * 需要改写成: - * insert into hotnews(id, title) values(next value for MYCATSEQ_hotnews,'aaa'); - */ - public static void handleBatchInsert(ServerConnection sc, SchemaConfig schema, - int sqlType,String origSQL, int valuesIndex,String tableName, String primaryKey , List> vauleList) { - - final String pk = "\\("+primaryKey+","; - final String mycatSeqPrefix = "(next value for MYCATSEQ_"+tableName.toUpperCase()+""; - - /*"VALUES".length() ==6 */ - String prefix = origSQL.substring(0, valuesIndex + 6); -// - prefix = prefix.replaceFirst("\\(", pk); - - StringBuilder sb = new StringBuilder(""); - for(List list : vauleList) { - sb.append(mycatSeqPrefix); - for(String val : list) { - sb.append(",").append(val); - } - sb.append("),"); - } - sb.setCharAt(sb.length() - 1, ' ');; - processSQL(sc, schema,prefix+sb.toString(), sqlType); - } -// /** -// * 对于主键不在插入语句的fields中的SQL,需要改写。比如hotnews主键为id,插入语句为: -// * insert into hotnews(title) values('aaa'); -// * 需要改写成: -// * insert into hotnews(id, title) values(next value for MYCATSEQ_hotnews,'aaa'); -// */ -// public static void handleBatchInsert(ServerConnection sc, SchemaConfig schema, -// int sqlType,String origSQL, int valuesIndex,String tableName, String primaryKey) { -// -// final String pk = "\\("+primaryKey+","; -// final String mycatSeqPrefix = "(next value for MYCATSEQ_"+tableName.toUpperCase()+","; -// -// /*"VALUES".length() ==6 */ -// String prefix = origSQL.substring(0, valuesIndex + 6); -// String values = origSQL.substring(valuesIndex + 6); -// -// prefix = prefix.replaceFirst("\\(", pk); -// values = values.replaceFirst("\\(", mycatSeqPrefix); -// values =Pattern.compile(",\\s*\\(").matcher(values).replaceAll(","+mycatSeqPrefix); -// processSQL(sc, schema,prefix+values, sqlType); -// } - - - public static RouteResultset routeToMultiNode(boolean cache,RouteResultset rrs, Collection dataNodes, String stmt) { - RouteResultsetNode[] nodes = new RouteResultsetNode[dataNodes.size()]; - int i = 0; - RouteResultsetNode node; - for (String dataNode : dataNodes) { - node = new RouteResultsetNode(dataNode, rrs.getSqlType(), stmt); - node.setSource(rrs); - if(rrs.getDataNodeSlotMap().containsKey(dataNode)){ - node.setSlot(rrs.getDataNodeSlotMap().get(dataNode)); - } - if (rrs.getCanRunInReadDB() != null) { - node.setCanRunInReadDB(rrs.getCanRunInReadDB()); - } - if(rrs.getRunOnSlave() != null){ - nodes[0].setRunOnSlave(rrs.getRunOnSlave()); - } - nodes[i++] = node; - } - rrs.setCacheAble(cache); - rrs.setNodes(nodes); - return rrs; - } - - public static RouteResultset routeToMultiNode(boolean cache, RouteResultset rrs, Collection dataNodes, - String stmt, boolean isGlobalTable) { - - rrs = routeToMultiNode(cache, rrs, dataNodes, stmt); - rrs.setGlobalTable(isGlobalTable); - return rrs; - } - - public static void routeForTableMeta(RouteResultset rrs, - SchemaConfig schema, String tableName, String sql) { - String dataNode = null; - if (isNoSharding(schema,tableName)) {//不分库的直接从schema中获取dataNode - dataNode = schema.getDataNode(); - } else { - dataNode = getMetaReadDataNode(schema, tableName); - } - - RouteResultsetNode[] nodes = new RouteResultsetNode[1]; - nodes[0] = new RouteResultsetNode(dataNode, rrs.getSqlType(), sql); - nodes[0].setSource(rrs); - if(rrs.getDataNodeSlotMap().containsKey(dataNode)){ - nodes[0].setSlot(rrs.getDataNodeSlotMap().get(dataNode)); - } - if (rrs.getCanRunInReadDB() != null) { - nodes[0].setCanRunInReadDB(rrs.getCanRunInReadDB()); - } - if(rrs.getRunOnSlave() != null){ - nodes[0].setRunOnSlave(rrs.getRunOnSlave()); - } - rrs.setNodes(nodes); - } - - /** - * 根据表名随机获取一个节点 - * - * @param schema 数据库名 - * @param table 表名 - * @return 数据节点 - * @author mycat - */ - private static String getMetaReadDataNode(SchemaConfig schema, - String table) { - // Table名字被转化为大写的,存储在schema - table = table.toUpperCase(); - String dataNode = null; - Map tables = schema.getTables(); - TableConfig tc; - if (tables != null && (tc = tables.get(table)) != null) { - dataNode = getAliveRandomDataNode(tc); - } - return dataNode; - } - - /** - * 解决getRandomDataNode方法获取错误节点的问题. - * @param tc - * @return - */ - private static String getAliveRandomDataNode(TableConfig tc) { - List randomDns = (List)tc.getDataNodes().clone(); - - MycatConfig mycatConfig = MycatServer.getInstance().getConfig(); - if (mycatConfig != null) { - Collections.shuffle(randomDns); - for (String randomDn : randomDns) { - PhysicalDBNode physicalDBNode = mycatConfig.getDataNodes().get(randomDn); - if (physicalDBNode != null) { - if (physicalDBNode.getDbPool().getSource().isAlive()) { - for (PhysicalDBPool pool : MycatServer.getInstance().getConfig().getDataHosts().values()) { - PhysicalDatasource source = pool.getSource(); - if (source.getHostConfig().containDataNode(randomDn) && pool.getSource().isAlive()) { - return randomDn; - } - } - } - } - } - } - - // all fail return default - return tc.getRandomDataNode(); - } - - @Deprecated - private static String getRandomDataNode(TableConfig tc) { - //写节点不可用,意味着读节点也不可用。 - //直接使用下一个 dataHost - String randomDn = tc.getRandomDataNode(); - MycatConfig mycatConfig = MycatServer.getInstance().getConfig(); - if (mycatConfig != null) { - PhysicalDBNode physicalDBNode = mycatConfig.getDataNodes().get(randomDn); - if (physicalDBNode != null) { - if (physicalDBNode.getDbPool().getSource().isAlive()) { - for (PhysicalDBPool pool : MycatServer.getInstance() - .getConfig() - .getDataHosts() - .values()) { - if (pool.getSource().getHostConfig().containDataNode(randomDn)) { - continue; - } - - if (pool.getSource().isAlive()) { - return pool.getSource().getHostConfig().getRandomDataNode(); - } - } - } - } - } - - //all fail return default - return randomDn; - } - - /** - * 根据 ER分片规则获取路由集合 - * - * @param stmt 执行的语句 - * @param rrs 数据路由集合 - * @param tc 表实体 - * @param joinKeyVal 连接属性 - * @return RouteResultset(数据路由集合) * - * @throws SQLNonTransientException,IllegalShardingColumnValueException - * @author mycat - */ - - public static RouteResultset routeByERParentKey(ServerConnection sc,SchemaConfig schema, - int sqlType,String stmt, - RouteResultset rrs, TableConfig tc, String joinKeyVal) - throws SQLNonTransientException { - - // only has one parent level and ER parent key is parent - // table's partition key - if (tc.isSecondLevel() - //判断是否为二级子表(父表不再有父表) - && tc.getParentTC().getPartitionColumn() - .equals(tc.getParentKey())) { // using - // parent - // rule to - // find - // datanode - Set parentColVal = new HashSet(1); - ColumnRoutePair pair = new ColumnRoutePair(joinKeyVal); - parentColVal.add(pair); - Set dataNodeSet = ruleCalculate(tc.getParentTC(), parentColVal,rrs.getDataNodeSlotMap()); - if (dataNodeSet.isEmpty() || dataNodeSet.size() > 1) { - throw new SQLNonTransientException( - "parent key can't find valid datanode ,expect 1 but found: " - + dataNodeSet.size()); - } - String dn = dataNodeSet.iterator().next(); - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("found partion node (using parent partion rule directly) for child table to insert " - + dn + " sql :" + stmt); - } - return RouterUtil.routeToSingleNode(rrs, dn, stmt); - } - return null; - } - - /** - * @return dataNodeIndex -> [partitionKeysValueTuple+] - */ - public static Set ruleByJoinValueCalculate(RouteResultset rrs, TableConfig tc, - Set colRoutePairSet) throws SQLNonTransientException { - - String joinValue = ""; - - if(colRoutePairSet.size() > 1) { - LOGGER.warn("joinKey can't have multi Value"); - } else { - Iterator it = colRoutePairSet.iterator(); - ColumnRoutePair joinCol = it.next(); - joinValue = joinCol.colValue; - } - - Set retNodeSet = new LinkedHashSet(); - - Set nodeSet; - if (tc.isSecondLevel() - && tc.getParentTC().getPartitionColumn() - .equals(tc.getParentKey())) { // using - // parent - // rule to - // find - // datanode - - nodeSet = ruleCalculate(tc.getParentTC(),colRoutePairSet,rrs.getDataNodeSlotMap()); - if (nodeSet.isEmpty()) { - throw new SQLNonTransientException( - "parent key can't find valid datanode ,expect 1 but found: " - + nodeSet.size()); - } - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("found partion node (using parent partion rule directly) for child table to insert " - + nodeSet + " sql :" + rrs.getStatement()); - } - retNodeSet.addAll(nodeSet); - -// for(ColumnRoutePair pair : colRoutePairSet) { -// nodeSet = ruleCalculate(tc.getParentTC(),colRoutePairSet); -// if (nodeSet.isEmpty() || nodeSet.size() > 1) {//an exception would be thrown, if sql was executed on more than on sharding -// throw new SQLNonTransientException( -// "parent key can't find valid datanode ,expect 1 but found: " -// + nodeSet.size()); -// } -// String dn = nodeSet.iterator().next(); -// if (LOGGER.isDebugEnabled()) { -// LOGGER.debug("found partion node (using parent partion rule directly) for child table to insert " -// + dn + " sql :" + rrs.getStatement()); -// } -// retNodeSet.addAll(nodeSet); -// } - return retNodeSet; - } else { - retNodeSet.addAll(tc.getParentTC().getDataNodes()); - } - - return retNodeSet; - } - - - /** - * @return dataNodeIndex -> [partitionKeysValueTuple+] - */ - public static Set ruleCalculate(TableConfig tc, - Set colRoutePairSet,Map dataNodeSlotMap) { - Set routeNodeSet = new LinkedHashSet(); - String col = tc.getRule().getColumn(); - RuleConfig rule = tc.getRule(); - AbstractPartitionAlgorithm algorithm = rule.getRuleAlgorithm(); - for (ColumnRoutePair colPair : colRoutePairSet) { - if (colPair.colValue != null) { - Integer nodeIndx = algorithm.calculate(colPair.colValue); - if (nodeIndx == null) { - throw new IllegalArgumentException( - "can't find datanode for sharding column:" + col - + " val:" + colPair.colValue); - } else { - String dataNode = tc.getDataNodes().get(nodeIndx); - routeNodeSet.add(dataNode); - if(algorithm instanceof SlotFunction) { - dataNodeSlotMap.put(dataNode,((SlotFunction) algorithm).slotValue()); - } - colPair.setNodeId(nodeIndx); - } - } else if (colPair.rangeValue != null) { - Integer[] nodeRange = algorithm.calculateRange( - String.valueOf(colPair.rangeValue.beginValue), - String.valueOf(colPair.rangeValue.endValue)); - if (nodeRange != null) { - /** - * 不能确认 colPair的 nodeid是否会有其它影响 - */ - if (nodeRange.length == 0) { - routeNodeSet.addAll(tc.getDataNodes()); - } else { - ArrayList dataNodes = tc.getDataNodes(); - String dataNode = null; - for (Integer nodeId : nodeRange) { - dataNode = dataNodes.get(nodeId); - if(algorithm instanceof SlotFunction) { - dataNodeSlotMap.put(dataNode,((SlotFunction) algorithm).slotValue()); - } - routeNodeSet.add(dataNode); - } - } - } - } - - } - return routeNodeSet; - } - - /** - * 多表路由 - */ - public static RouteResultset tryRouteForTables(SchemaConfig schema, DruidShardingParseInfo ctx, - RouteCalculateUnit routeUnit, RouteResultset rrs, boolean isSelect, LayerCachePool cachePool) - throws SQLNonTransientException { - - List tables = ctx.getTables(); - - //每个表对应的路由映射 - Map> tablesRouteMap = new HashMap>(); - - //为全局表和单库表找路由 - for(String tableName : tables) { - - TableConfig tableConfig = schema.getTables().get(tableName.toUpperCase()); - - if(tableConfig == null) { - //add 如果表读取不到则先将表名从别名中读取转化后再读取 - String alias = ctx.getTableAliasMap().get(tableName); - if(!StringUtil.isEmpty(alias)){ - tableConfig = schema.getTables().get(alias.toUpperCase()); - } - - if(tableConfig == null){ - String msg = "can't find table define in schema "+ tableName + " schema:" + schema.getName(); - LOGGER.warn(msg); - throw new SQLNonTransientException(msg); - } - - } - if(tableConfig.isGlobalTable()) {//全局表 - if(tablesRouteMap.get(tableName) == null) { - tablesRouteMap.put(tableName, new HashSet()); - } - tablesRouteMap.get(tableName).addAll(tableConfig.getDataNodes()); - } else if(tablesRouteMap.get(tableName) == null) { //余下的表都是单库表 - tablesRouteMap.put(tableName, new HashSet()); - tablesRouteMap.get(tableName).addAll(tableConfig.getDataNodes()); - } - - if(tableConfig.getDistTables().size() > 0) { - Map> subTablesmap = rrs.getSubTableMaps(); - if (subTablesmap == null) { - subTablesmap = Maps.newHashMap(); - rrs.setSubTableMaps(subTablesmap); - } - - subTablesmap.put(tableName.toUpperCase(), tableConfig.getDistTables()); - } - } - - if(schema.isNoSharding()||(tables.size() >= 1&&isNoSharding(schema,tables.get(0)))) { - return routeToSingleNode(rrs, schema.getDataNode(), ctx.getSql()); - } - - //只有一个表的 - if(tables.size() == 1) { - return RouterUtil.tryRouteForOneTable(schema, ctx, routeUnit, tables.get(0), rrs, isSelect, cachePool); - } - - Set retNodesSet = new HashSet(); - - //分库解析信息不为空 - Map>> tablesAndConditions = routeUnit.getTablesAndConditions(); - if(tablesAndConditions != null && tablesAndConditions.size() > 0) { - //为分库表找路由 - RouterUtil.findRouteWithcConditionsForTables(schema, rrs, tablesAndConditions, tablesRouteMap, ctx.getSql(), cachePool, isSelect); - if(rrs.isFinishedRoute()) { - return rrs; - } - } - - - boolean isFirstAdd = true; - for(Map.Entry> entry : tablesRouteMap.entrySet()) { - if(entry.getValue() == null || entry.getValue().size() == 0) { - throw new SQLNonTransientException("parent key can't find any valid datanode "); - } else { - if(isFirstAdd) { - retNodesSet.addAll(entry.getValue()); - isFirstAdd = false; - } else { - retNodesSet.retainAll(entry.getValue()); - if(retNodesSet.size() == 0) {//两个表的路由无交集 - String errMsg = "invalid route in sql, multi tables found but datanode has no intersection " - + " sql:" + ctx.getSql(); - LOGGER.warn(errMsg); - throw new SQLNonTransientException(errMsg); - } - } - } - } - - if(retNodesSet != null && retNodesSet.size() > 0) { - String tableName = tables.get(0); - TableConfig tableConfig = schema.getTables().get(tableName.toUpperCase()); - if(tableConfig.isDistTable()){ - routeToDistTableNode(schema, rrs, ctx.getSql(), tablesAndConditions, cachePool, isSelect); - return rrs; - } - - if(retNodesSet.size() > 1 && isAllGlobalTable(ctx, schema)) { - // mulit routes ,not cache route result - if (isSelect) { - rrs.setCacheAble(false); - ArrayList retNodeList = new ArrayList(retNodesSet); - Collections.shuffle(retNodeList);//by kaiz : add shuffle - routeToSingleNode(rrs, retNodeList.get(0), ctx.getSql()); - } - else {//delete 删除全局表的记录 - routeToMultiNode(isSelect, rrs, retNodesSet, ctx.getSql(),true); - } - - } else { - routeToMultiNode(isSelect, rrs, retNodesSet, ctx.getSql()); - } - - } - return rrs; - - } - - - /** - * - * 单表路由 - */ - public static RouteResultset tryRouteForOneTable(SchemaConfig schema, DruidShardingParseInfo ctx, - RouteCalculateUnit routeUnit, String tableName, RouteResultset rrs, boolean isSelect, - LayerCachePool cachePool) throws SQLNonTransientException { - - if (isNoSharding(schema, tableName)) { - return routeToSingleNode(rrs, schema.getDataNode(), ctx.getSql()); - } - - TableConfig tc = schema.getTables().get(tableName); - if(tc == null) { - String msg = "can't find table define in schema " + tableName + " schema:" + schema.getName(); - LOGGER.warn(msg); - throw new SQLNonTransientException(msg); - } - - if(tc.isDistTable()){ - return routeToDistTableNode(schema,rrs,ctx.getSql(), routeUnit.getTablesAndConditions(), cachePool,isSelect); - } - - if(tc.isGlobalTable()) {//全局表 - if(isSelect) { - // global select ,not cache route result - rrs.setCacheAble(false); - return routeToSingleNode(rrs, getAliveRandomDataNode(tc)/*getRandomDataNode(tc)*/, ctx.getSql()); - } else {//insert into 全局表的记录 - return routeToMultiNode(false, rrs, tc.getDataNodes(), ctx.getSql(),true); - } - } else {//单表或者分库表 - if (!checkRuleRequired(schema, ctx, routeUnit, tc)) { - throw new IllegalArgumentException("route rule for table " - + tc.getName() + " is required: " + ctx.getSql()); - - } - if(tc.getPartitionColumn() == null && !tc.isSecondLevel()) {//单表且不是childTable -// return RouterUtil.routeToSingleNode(rrs, tc.getDataNodes().get(0),ctx.getSql()); - return routeToMultiNode(rrs.isCacheAble(), rrs, tc.getDataNodes(), ctx.getSql()); - } else { - //每个表对应的路由映射 - Map> tablesRouteMap = new HashMap>(); - if(routeUnit.getTablesAndConditions() != null && routeUnit.getTablesAndConditions().size() > 0) { - RouterUtil.findRouteWithcConditionsForTables(schema, rrs, routeUnit.getTablesAndConditions(), tablesRouteMap, ctx.getSql(), cachePool, isSelect); - if(rrs.isFinishedRoute()) { - return rrs; - } - } - - if(tablesRouteMap.get(tableName) == null) { - return routeToMultiNode(rrs.isCacheAble(), rrs, tc.getDataNodes(), ctx.getSql()); - } else { - return routeToMultiNode(rrs.isCacheAble(), rrs, tablesRouteMap.get(tableName), ctx.getSql()); - } - } - } - } - - private static RouteResultset routeToDistTableNode(SchemaConfig schema, RouteResultset rrs, - String orgSql, Map>> tablesAndConditions, - LayerCachePool cachePool, boolean isSelect) throws SQLNonTransientException { - List tables = rrs.getTables(); - - String tableName = tables.get(0); - TableConfig tableConfig = schema.getTables().get(tableName); - if(tableConfig == null) { - String msg = "can't find table define in schema " + tableName + " schema:" + schema.getName(); - LOGGER.warn(msg); - throw new SQLNonTransientException(msg); - } - if(tableConfig.isGlobalTable()){ - String msg = "can't suport district table " + tableName + " schema:" + schema.getName() + " for global table "; - LOGGER.warn(msg); - throw new SQLNonTransientException(msg); - } - String partionCol = tableConfig.getPartitionColumn(); -// String primaryKey = tableConfig.getPrimaryKey(); - boolean isLoadData=false; - - Set tablesRouteSet = new HashSet(); - - List dataNodes = tableConfig.getDataNodes(); - if(dataNodes.size()>1){ - String msg = "can't suport district table " + tableName + " schema:" + schema.getName() + " for mutiple dataNode " + dataNodes; - LOGGER.warn(msg); - throw new SQLNonTransientException(msg); - } - String dataNode = dataNodes.get(0); - - RouteResultsetNode[] nodes = null; - //主键查找缓存暂时不实现 - if(tablesAndConditions.isEmpty()){ - List subTables = tableConfig.getDistTables(); - tablesRouteSet.addAll(subTables); - - nodes = getNode(rrs, orgSql, tablesRouteSet, dataNode, false, tableName); - } else { - - for(Map.Entry>> entry : tablesAndConditions.entrySet()) { - boolean isFoundPartitionValue = partionCol != null && entry.getValue().get(partionCol) != null; - Map> columnsMap = entry.getValue(); - - Set partitionValue = columnsMap.get(partionCol); - if(partitionValue == null || partitionValue.size() == 0) { - tablesRouteSet.addAll(tableConfig.getDistTables()); - } else { - for(ColumnRoutePair pair : partitionValue) { - AbstractPartitionAlgorithm algorithm = tableConfig.getRule().getRuleAlgorithm(); - if(pair.colValue != null) { - Integer tableIndex = algorithm.calculate(pair.colValue); - if(tableIndex == null) { - String msg = "can't find any valid datanode :" + tableConfig.getName() - + " -> " + tableConfig.getPartitionColumn() + " -> " + pair.colValue; - LOGGER.warn(msg); - throw new SQLNonTransientException(msg); - } - String subTable = tableConfig.getDistTables().get(tableIndex); - if(subTable != null) { - tablesRouteSet.add(subTable); - if(algorithm instanceof SlotFunction){ - rrs.getDataNodeSlotMap().put(subTable,((SlotFunction) algorithm).slotValue()); - } - } - } - if(pair.rangeValue != null) { - Integer[] tableIndexs = algorithm - .calculateRange(pair.rangeValue.beginValue.toString(), pair.rangeValue.endValue.toString()); - for(Integer idx : tableIndexs) { - String subTable = tableConfig.getDistTables().get(idx); - if(subTable != null) { - tablesRouteSet.add(subTable); - if(algorithm instanceof SlotFunction){ - rrs.getDataNodeSlotMap().put(subTable,((SlotFunction) algorithm).slotValue()); - } - } - } - } - } - } - } - - nodes = getNode(rrs, orgSql, tablesRouteSet, dataNode, true, tableName); - } - - rrs.setNodes(nodes); - rrs.setSubTables(tablesRouteSet); - rrs.setFinishedRoute(true); - - return rrs; - } - - private static RouteResultsetNode[] getNode(RouteResultset rrs, String orgSql, Set tablesRouteSet, - String dataNode, boolean is, String tableName) { - Object[] subTables = tablesRouteSet.toArray(); - RouteResultsetNode[] nodes = new RouteResultsetNode[subTables.length]; - Map dataNodeSlotMap= rrs.getDataNodeSlotMap(); - for(int i=0;i> subTableMaps = rrs.getSubTableMaps(); - if(subTableMaps != null) { - List list = subTableMaps.get(tableName); - int index = 0; - for (String subTable : list) { - if (table.equals(subTable)) { - break; - } - index++; - } - for (String tableSource : subTableMaps.keySet()) { - Map subTableNames = nodes[i].getSubTableNames(); - if (subTableNames == null) { - subTableNames = Maps.newHashMap(); - nodes[i].setSubTableNames(subTableNames); - } - if (tableSource.equals(tableName)) { - subTableNames.put(tableSource, table); - } else { - subTableNames.put(tableSource, subTableMaps.get(tableSource).get(index)); - } - - } - } - } else { - Map> subTableMaps = rrs.getSubTableMaps(); - if(subTableMaps != null) { - for (String tableSource : subTableMaps.keySet()) { - Map subTableNames = nodes[i].getSubTableNames(); - if (subTableNames == null) { - subTableNames = Maps.newHashMap(); - nodes[i].setSubTableNames(subTableNames); - } - subTableNames.put(tableSource, subTableMaps.get(tableSource).get(i)); - } - } - } - - nodes[i].setSource(rrs); - if(rrs.getDataNodeSlotMap().containsKey(dataNode)){ - nodes[i].setSlot(rrs.getDataNodeSlotMap().get(dataNode)); - } - if (rrs.getCanRunInReadDB() != null) { - nodes[i].setCanRunInReadDB(rrs.getCanRunInReadDB()); - } - if(dataNodeSlotMap.containsKey(table)) { - nodes[i].setSlot(dataNodeSlotMap.get(table)); - } - if(rrs.getRunOnSlave() != null){ - nodes[0].setRunOnSlave(rrs.getRunOnSlave()); - } - } - return nodes; - } - - /** - * 处理分库表路由 - */ - public static void findRouteWithcConditionsForTables(SchemaConfig schema, RouteResultset rrs, - Map>> tablesAndConditions, - Map> tablesRouteMap, String sql, LayerCachePool cachePool, boolean isSelect) - throws SQLNonTransientException { - - //为分库表找路由 - for(Map.Entry>> entry : tablesAndConditions.entrySet()) { - String tableName = entry.getKey().toUpperCase(); - TableConfig tableConfig = schema.getTables().get(tableName); - if(tableConfig == null) { - String msg = "can't find table define in schema " - + tableName + " schema:" + schema.getName(); - LOGGER.warn(msg); - throw new SQLNonTransientException(msg); - } - if(tableConfig.getDistTables()!=null && tableConfig.getDistTables().size()>0){ - routeToDistTableNode(schema,rrs,sql, tablesAndConditions, cachePool,isSelect); - } - //全局表或者不分库的表略过(全局表后面再计算) - if(tableConfig.isGlobalTable() || schema.getTables().get(tableName).getDataNodes().size() == 1) { - continue; - } else {//非全局表:分库表、childTable、其他 - Map> columnsMap = entry.getValue(); - String joinKey = tableConfig.getJoinKey(); - String partionCol = tableConfig.getPartitionColumn(); - String primaryKey = tableConfig.getPrimaryKey(); - boolean isFoundPartitionValue = partionCol != null && entry.getValue().get(partionCol) != null; - boolean isLoadData=false; - if (LOGGER.isDebugEnabled() - && sql.startsWith(LoadData.loadDataHint)||rrs.isLoadData()) { - //由于load data一次会计算很多路由数据,如果输出此日志会极大降低load data的性能 - isLoadData=true; - } - if(entry.getValue().get(primaryKey) != null && entry.getValue().size() == 1&&!isLoadData) - {//主键查找 - // try by primary key if found in cache - Set primaryKeyPairs = entry.getValue().get(primaryKey); - if (primaryKeyPairs != null) { - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("try to find cache by primary key "); - } - String tableKey = schema.getName() + '_' + tableName; - boolean allFound = true; - for (ColumnRoutePair pair : primaryKeyPairs) {//可能id in(1,2,3)多主键 - String cacheKey = pair.colValue; - String dataNode = (String) cachePool.get(tableKey, cacheKey); - if (dataNode == null) { - allFound = false; - continue; - } else { - if(tablesRouteMap.get(tableName) == null) { - tablesRouteMap.put(tableName, new HashSet()); - } - tablesRouteMap.get(tableName).add(dataNode); - continue; - } - } - if (!allFound) { - // need cache primary key ->datanode relation - if (isSelect && tableConfig.getPrimaryKey() != null) { - rrs.setPrimaryKey(tableKey + '.' + tableConfig.getPrimaryKey()); - } - } else {//主键缓存中找到了就执行循环的下一轮 - continue; - } - } - } - if (isFoundPartitionValue) {//分库表 - Set partitionValue = columnsMap.get(partionCol); - if(partitionValue == null || partitionValue.size() == 0) { - if(tablesRouteMap.get(tableName) == null) { - tablesRouteMap.put(tableName, new HashSet()); - } - tablesRouteMap.get(tableName).addAll(tableConfig.getDataNodes()); - } else { - for(ColumnRoutePair pair : partitionValue) { - AbstractPartitionAlgorithm algorithm = tableConfig.getRule().getRuleAlgorithm(); - if(pair.colValue != null) { - Integer nodeIndex = algorithm.calculate(pair.colValue); - if(nodeIndex == null) { - String msg = "can't find any valid datanode :" + tableConfig.getName() - + " -> " + tableConfig.getPartitionColumn() + " -> " + pair.colValue; - LOGGER.warn(msg); - throw new SQLNonTransientException(msg); - } - - ArrayList dataNodes = tableConfig.getDataNodes(); - String node; - if (nodeIndex >=0 && nodeIndex < dataNodes.size()) { - node = dataNodes.get(nodeIndex); - - } else { - node = null; - String msg = "Can't find a valid data node for specified node index :" - + tableConfig.getName() + " -> " + tableConfig.getPartitionColumn() - + " -> " + pair.colValue + " -> " + "Index : " + nodeIndex; - LOGGER.warn(msg); - throw new SQLNonTransientException(msg); - } - if(node != null) { - if(tablesRouteMap.get(tableName) == null) { - tablesRouteMap.put(tableName, new HashSet()); - } - if(algorithm instanceof SlotFunction){ - rrs.getDataNodeSlotMap().put(node,((SlotFunction) algorithm).slotValue()); - } - tablesRouteMap.get(tableName).add(node); - } - } - if(pair.rangeValue != null) { - Integer[] nodeIndexs = algorithm - .calculateRange(pair.rangeValue.beginValue.toString(), pair.rangeValue.endValue.toString()); - ArrayList dataNodes = tableConfig.getDataNodes(); - String node; - for(Integer idx : nodeIndexs) { - if (idx >= 0 && idx < dataNodes.size()) { - node = dataNodes.get(idx); - } else { - String msg = "Can't find valid data node(s) for some of specified node indexes :" - + tableConfig.getName() + " -> " + tableConfig.getPartitionColumn(); - LOGGER.warn(msg); - throw new SQLNonTransientException(msg); - } - if(node != null) { - if(tablesRouteMap.get(tableName) == null) { - tablesRouteMap.put(tableName, new HashSet()); - } - if(algorithm instanceof SlotFunction){ - rrs.getDataNodeSlotMap().put(node,((SlotFunction) algorithm).slotValue()); - } - tablesRouteMap.get(tableName).add(node); - - } - } - } - } - } - } else if(joinKey != null && columnsMap.get(joinKey) != null && columnsMap.get(joinKey).size() != 0) {//childTable (如果是select 语句的父子表join)之前要找到root table,将childTable移除,只留下root table - Set joinKeyValue = columnsMap.get(joinKey); - - Set dataNodeSet = ruleByJoinValueCalculate(rrs, tableConfig, joinKeyValue); - - if (dataNodeSet.isEmpty()) { - throw new SQLNonTransientException( - "parent key can't find any valid datanode "); - } - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("found partion nodes (using parent partion rule directly) for child table to update " - + Arrays.toString(dataNodeSet.toArray()) + " sql :" + sql); - } - if (dataNodeSet.size() > 1) { - routeToMultiNode(rrs.isCacheAble(), rrs, dataNodeSet, sql); - rrs.setFinishedRoute(true); - return; - } else { - rrs.setCacheAble(true); - routeToSingleNode(rrs, dataNodeSet.iterator().next(), sql); - return; - } - - } else { - //没找到拆分字段,该表的所有节点都路由 - if(tablesRouteMap.get(tableName) == null) { - tablesRouteMap.put(tableName, new HashSet()); - } - boolean isSlotFunction= tableConfig.getRule() != null && tableConfig.getRule().getRuleAlgorithm() instanceof SlotFunction; - if(isSlotFunction){ - for (String dn : tableConfig.getDataNodes()) { - rrs.getDataNodeSlotMap().put(dn,-1); - } - } - tablesRouteMap.get(tableName).addAll(tableConfig.getDataNodes()); - } - } - } - } - - public static boolean isAllGlobalTable(DruidShardingParseInfo ctx, SchemaConfig schema) { - boolean isAllGlobal = false; - for(String table : ctx.getTables()) { - TableConfig tableConfig = schema.getTables().get(table); - if(tableConfig!=null && tableConfig.isGlobalTable()) { - isAllGlobal = true; - } else { - return false; - } - } - return isAllGlobal; - } - - /** - * - * @param schema - * @param ctx - * @param tc - * @return true表示校验通过,false表示检验不通过 - */ - public static boolean checkRuleRequired(SchemaConfig schema, DruidShardingParseInfo ctx, RouteCalculateUnit routeUnit, TableConfig tc) { - if(!tc.isRuleRequired()) { - return true; - } - boolean hasRequiredValue = false; - String tableName = tc.getName(); - if(routeUnit.getTablesAndConditions().get(tableName) == null || routeUnit.getTablesAndConditions().get(tableName).size() == 0) { - hasRequiredValue = false; - } else { - for(Map.Entry> condition : routeUnit.getTablesAndConditions().get(tableName).entrySet()) { - - String colName = condition.getKey(); - //条件字段是拆分字段 - if(colName.equals(tc.getPartitionColumn())) { - hasRequiredValue = true; - break; - } - } - } - return hasRequiredValue; - } - - - /** - * 增加判断支持未配置分片的表走默认的dataNode - * @param schemaConfig - * @param tableName - * @return - */ - public static boolean isNoSharding(SchemaConfig schemaConfig, String tableName) { - // Table名字被转化为大写的,存储在schema - tableName = tableName.toUpperCase(); - if (schemaConfig.isNoSharding()) { - return true; - } - - if (schemaConfig.getDataNode() != null && !schemaConfig.getTables().containsKey(tableName)) { - return true; - } - - return false; - } - - /** - * 系统表判断,某些sql语句会查询系统表或者跟系统表关联 - * @author lian - * @date 2016年12月2日 - * @param tableName - * @return - */ - public static boolean isSystemSchema(String tableName) { - // 以information_schema, mysql开头的是系统表 - if (tableName.startsWith("INFORMATION_SCHEMA.") - || tableName.startsWith("MYSQL.") - || tableName.startsWith("PERFORMANCE_SCHEMA.")) { - return true; - } - - return false; - } - - /** - * 判断条件是否永真 - * @param expr - * @return - */ - public static boolean isConditionAlwaysTrue(SQLExpr expr) { - Object o = WallVisitorUtils.getValue(expr); - if(Boolean.TRUE.equals(o)) { - return true; - } - return false; - } - - /** - * 判断条件是否永假的 - * @param expr - * @return - */ - public static boolean isConditionAlwaysFalse(SQLExpr expr) { - Object o = WallVisitorUtils.getValue(expr); - if(Boolean.FALSE.equals(o)) { - return true; - } - return false; - } - - - /** - * 该方法,返回是否是ER子表 - * @param schema - * @param origSQL - * @param sc - * @return - * @throws SQLNonTransientException - * - * 备注说明: - * edit by ding.w at 2017.4.28, 主要处理 CLIENT_MULTI_STATEMENTS(insert into ; insert into)的情况 - * 目前仅支持mysql,并COM_QUERY请求包中的所有insert语句要么全部是er表,要么全部不是 - * - * - */ - public static boolean processERChildTable(final SchemaConfig schema, final String origSQL, - final ServerConnection sc) throws SQLNonTransientException { - - MySqlStatementParser parser = new MySqlStatementParser(origSQL); - List statements = parser.parseStatementList(); - - if(statements == null || statements.isEmpty() ) { - throw new SQLNonTransientException(String.format("无效的SQL语句:%s", origSQL)); - } - - - boolean erFlag = false; //是否是er表 - for(SQLStatement stmt : statements ) { - MySqlInsertStatement insertStmt = (MySqlInsertStatement) stmt; - String tableName = insertStmt.getTableName().getSimpleName().toUpperCase(); - final TableConfig tc = schema.getTables().get(tableName); - - if (null != tc && tc.isChildTable()) { - erFlag = true; - - String sql = insertStmt.toString(); - - final RouteResultset rrs = new RouteResultset(sql, ServerParse.INSERT); - String joinKey = tc.getJoinKey(); - //因为是Insert语句,用MySqlInsertStatement进行parse -// MySqlInsertStatement insertStmt = (MySqlInsertStatement) (new MySqlStatementParser(origSQL)).parseInsert(); - //判断条件完整性,取得解析后语句列中的joinkey列的index - int joinKeyIndex = getJoinKeyIndex(insertStmt.getColumns(), joinKey); - if (joinKeyIndex == -1) { - String inf = "joinKey not provided :" + tc.getJoinKey() + "," + insertStmt; - LOGGER.warn(inf); - throw new SQLNonTransientException(inf); - } - //子表不支持批量插入 - if (isMultiInsert(insertStmt)) { - String msg = "ChildTable multi insert not provided"; - LOGGER.warn(msg); - throw new SQLNonTransientException(msg); - } - //取得joinkey的值 - String joinKeyVal = insertStmt.getValues().getValues().get(joinKeyIndex).toString(); - //解决bug #938,当关联字段的值为char类型时,去掉前后"'" - String realVal = joinKeyVal; - if (joinKeyVal.startsWith("'") && joinKeyVal.endsWith("'") && joinKeyVal.length() > 2) { - realVal = joinKeyVal.substring(1, joinKeyVal.length() - 1); - } - - - - // try to route by ER parent partion key - //如果是二级子表(父表不再有父表),并且分片字段正好是joinkey字段,调用routeByERParentKey - RouteResultset theRrs = RouterUtil.routeByERParentKey(sc, schema, ServerParse.INSERT, sql, rrs, tc, realVal); - if (theRrs != null) { - boolean processedInsert=false; - //判断是否需要全局序列号 - if ( sc!=null && tc.isAutoIncrement()) { - String primaryKey = tc.getPrimaryKey(); - processedInsert=processInsert(sc,schema,ServerParse.INSERT,sql,tc.getName(),primaryKey); - } - if(processedInsert==false){ - rrs.setFinishedRoute(true); - sc.getSession2().execute(rrs, ServerParse.INSERT); - } - // return true; - //继续处理下一条 - continue; - } - - // route by sql query root parent's datanode - //如果不是二级子表或者分片字段不是joinKey字段结果为空,则启动异步线程去后台分片查询出datanode - //只要查询出上一级表的parentkey字段的对应值在哪个分片即可 - final String findRootTBSql = tc.getLocateRTableKeySql().toLowerCase() + joinKeyVal; - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("find root parent's node sql " + findRootTBSql); - } - - ListenableFuture listenableFuture = MycatServer.getInstance(). - getListeningExecutorService().submit(new Callable() { - @Override - public String call() throws Exception { - FetchStoreNodeOfChildTableHandler fetchHandler = new FetchStoreNodeOfChildTableHandler(); -// return fetchHandler.execute(schema.getName(), findRootTBSql, tc.getRootParent().getDataNodes()); - return fetchHandler.execute(schema.getName(), findRootTBSql, tc.getRootParent().getDataNodes(), sc); - } - }); - - - Futures.addCallback(listenableFuture, new FutureCallback() { - @Override - public void onSuccess(String result) { - //结果为空,证明上一级表中不存在那条记录,失败 - if (Strings.isNullOrEmpty(result)) { - StringBuilder s = new StringBuilder(); - LOGGER.warn(s.append(sc.getSession2()).append(origSQL).toString() + - " err:" + "can't find (root) parent sharding node for sql:" + origSQL); - if(!sc.isAutocommit()) { // 处于事务下失败, 必须回滚 - sc.setTxInterrupt("can't find (root) parent sharding node for sql:" + origSQL); - } - sc.writeErrMessage(ErrorCode.ER_PARSE_ERROR, "can't find (root) parent sharding node for sql:" + origSQL); - return; - } - - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("found partion node for child table to insert " + result + " sql :" + origSQL); - } - //找到分片,进行插入(和其他的一样,需要判断是否需要全局自增ID) - boolean processedInsert=false; - if ( sc!=null && tc.isAutoIncrement()) { - try { - String primaryKey = tc.getPrimaryKey(); - processedInsert=processInsert(sc,schema,ServerParse.INSERT,origSQL,tc.getName(),primaryKey); - } catch (SQLNonTransientException e) { - LOGGER.warn("sequence processInsert error,",e); - sc.writeErrMessage(ErrorCode.ER_PARSE_ERROR , "sequence processInsert error," + e.getMessage()); - } - } - if(processedInsert==false){ - RouteResultset executeRrs = RouterUtil.routeToSingleNode(rrs, result, origSQL); - sc.getSession2().execute(executeRrs, ServerParse.INSERT); - } - - } - - @Override - public void onFailure(Throwable t) { - StringBuilder s = new StringBuilder(); - LOGGER.warn(s.append(sc.getSession2()).append(origSQL).toString() + - " err:" + t.getMessage()); - sc.writeErrMessage(ErrorCode.ER_PARSE_ERROR, t.getMessage() + " " + s.toString()); - } - }, MycatServer.getInstance(). - getListeningExecutorService()); - - } else if(erFlag) { - throw new SQLNonTransientException(String.format("%s包含不是ER分片的表", origSQL)); - } - } - - - return erFlag; - } - - /** - * 寻找joinKey的索引 - * - * @param columns - * @param joinKey - * @return -1表示没找到,>=0表示找到了 - */ - private static int getJoinKeyIndex(List columns, String joinKey) { - for (int i = 0; i < columns.size(); i++) { - String col = StringUtil.removeBackquote(columns.get(i).toString()).toUpperCase(); - if (col.equals(joinKey)) { - return i; - } - } - return -1; - } - - /** - * 是否为批量插入:insert into ...values (),()...或 insert into ...select..... - * - * @param insertStmt - * @return - */ - private static boolean isMultiInsert(MySqlInsertStatement insertStmt) { - return (insertStmt.getValuesList() != null && insertStmt.getValuesList().size() > 1) - || insertStmt.getQuery() != null; - } - -} +package io.mycat.route.util; + +import java.sql.SQLNonTransientException; +import java.sql.SQLSyntaxErrorException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.regex.Pattern; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.alibaba.druid.sql.ast.SQLExpr; +import com.alibaba.druid.sql.ast.SQLStatement; +import com.alibaba.druid.sql.ast.expr.SQLCharExpr; +import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr; +import com.alibaba.druid.sql.ast.statement.SQLCharacterDataType; +import com.alibaba.druid.sql.ast.statement.SQLColumnDefinition; +import com.alibaba.druid.sql.ast.statement.SQLCreateTableStatement; +import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlCreateTableStatement; +import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement; +import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser; +import com.alibaba.druid.wall.spi.WallVisitorUtils; +import com.google.common.base.Strings; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; + +import io.mycat.MycatServer; +import io.mycat.backend.datasource.PhysicalDBNode; +import io.mycat.backend.datasource.PhysicalDBPool; +import io.mycat.backend.datasource.PhysicalDatasource; +import io.mycat.backend.mysql.nio.handler.FetchStoreNodeOfChildTableHandler; +import io.mycat.cache.LayerCachePool; +import io.mycat.config.ErrorCode; +import io.mycat.config.MycatConfig; +import io.mycat.config.model.SchemaConfig; +import io.mycat.config.model.TableConfig; +import io.mycat.config.model.rule.RuleConfig; +import io.mycat.route.RouteResultset; +import io.mycat.route.RouteResultsetNode; +import io.mycat.route.SessionSQLPair; +import io.mycat.route.function.AbstractPartitionAlgorithm; +import io.mycat.route.function.SlotFunction; +import io.mycat.route.parser.druid.DruidShardingParseInfo; +import io.mycat.route.parser.druid.RouteCalculateUnit; +import io.mycat.server.ServerConnection; +import io.mycat.server.parser.ServerParse; +import io.mycat.sqlengine.mpp.ColumnRoutePair; +import io.mycat.sqlengine.mpp.LoadData; +import io.mycat.util.StringUtil; + +/** + * 从ServerRouterUtil中抽取的一些公用方法,路由解析工具类 + * @author wang.dw + * + */ +public class RouterUtil { + + private static final Logger LOGGER = LoggerFactory.getLogger(RouterUtil.class); + + /** + * 移除执行语句中的数据库名 + * + * @param stmt 执行语句 + * @param schema 数据库名 + * @return 执行语句 + * @author mycat + * + * @modification 修正移除schema的方法 + * @date 2016/12/29 + * @modifiedBy Hash Zhang + * + */ + public static String removeSchema(String stmt, String schema) { + final String upStmt = stmt.toUpperCase(); + final String upSchema = schema.toUpperCase() + "."; + final String upSchema2 = new StringBuilder("`").append(schema.toUpperCase()).append("`.").toString(); + int strtPos = 0; + int indx = 0; + + int indx1 = upStmt.indexOf(upSchema, strtPos); + int indx2 = upStmt.indexOf(upSchema2, strtPos); + boolean flag = indx1 < indx2 ? indx1 == -1 : indx2 != -1; + indx = !flag ? indx1 > 0 ? indx1 : indx2 : indx2 > 0 ? indx2 : indx1; + if (indx < 0) { + return stmt; + } + + int firstE = upStmt.indexOf("'"); + int endE = upStmt.lastIndexOf("'"); + + StringBuilder sb = new StringBuilder(); + while (indx > 0) { + sb.append(stmt.substring(strtPos, indx)); + + if (flag) { + strtPos = indx + upSchema2.length(); + } else { + strtPos = indx + upSchema.length(); + } + if (indx > firstE && indx < endE && countChar(stmt, indx) % 2 == 1) { + sb.append(stmt.substring(indx, indx + schema.length() + 1)); + } + indx1 = upStmt.indexOf(upSchema, strtPos); + indx2 = upStmt.indexOf(upSchema2, strtPos); + flag = indx1 < indx2 ? indx1 == -1 : indx2 != -1; + indx = !flag ? indx1 > 0 ? indx1 : indx2 : indx2 > 0 ? indx2 : indx1; + } + sb.append(stmt.substring(strtPos)); + return sb.toString(); + } + + private static int countChar(String sql,int end) + { + int count=0; + boolean skipChar = false; + for (int i = 0; i < end; i++) { + if(sql.charAt(i)=='\'' && !skipChar) { + count++; + skipChar = false; + }else if( sql.charAt(i)=='\\'){ + skipChar = true; + }else{ + skipChar = false; + } + } + return count; + } + + /** + * 获取第一个节点作为路由 + * + * @param rrs 数据路由集合 + * @param dataNode 数据库所在节点 + * @param stmt 执行语句 + * @return 数据路由集合 + * + * @author mycat + */ + public static RouteResultset routeToSingleNode(RouteResultset rrs, + String dataNode, String stmt) { + if (dataNode == null) { + return rrs; + } + RouteResultsetNode[] nodes = new RouteResultsetNode[1]; + nodes[0] = new RouteResultsetNode(dataNode, rrs.getSqlType(), stmt);//rrs.getStatement() + nodes[0].setSource(rrs); + rrs.setNodes(nodes); + rrs.setFinishedRoute(true); + if(rrs.getDataNodeSlotMap().containsKey(dataNode)){ + nodes[0].setSlot(rrs.getDataNodeSlotMap().get(dataNode)); + } + if (rrs.getCanRunInReadDB() != null) { + nodes[0].setCanRunInReadDB(rrs.getCanRunInReadDB()); + } + if(rrs.getRunOnSlave() != null){ + nodes[0].setRunOnSlave(rrs.getRunOnSlave()); + } + + return rrs; + } + + + + /** + * 修复DDL路由 + * + * @return RouteResultset + * @author aStoneGod + */ + public static RouteResultset routeToDDLNode(RouteResultset rrs, int sqlType, String stmt,SchemaConfig schema) throws SQLSyntaxErrorException { + stmt = getFixedSql(stmt); + String tablename = ""; + final String upStmt = stmt.toUpperCase(); + if(upStmt.startsWith("CREATE")){ + if (upStmt.contains("CREATE INDEX ") || upStmt.contains("CREATE UNIQUE INDEX ")){ + tablename = RouterUtil.getTableName(stmt, RouterUtil.getCreateIndexPos(upStmt, 0)); + }else { + tablename = RouterUtil.getTableName(stmt, RouterUtil.getCreateTablePos(upStmt, 0)); + } + }else if(upStmt.startsWith("DROP")){ + if (upStmt.contains("DROP INDEX ")){ + tablename = RouterUtil.getTableName(stmt, RouterUtil.getDropIndexPos(upStmt, 0)); + }else { + tablename = RouterUtil.getTableName(stmt, RouterUtil.getDropTablePos(upStmt, 0)); + } + }else if(upStmt.startsWith("ALTER")){ + tablename = RouterUtil.getTableName(stmt, RouterUtil.getAlterTablePos(upStmt, 0)); + }else if (upStmt.startsWith("TRUNCATE")){ + tablename = RouterUtil.getTableName(stmt, RouterUtil.getTruncateTablePos(upStmt, 0)); + } + tablename = tablename.toUpperCase(); + + if (schema.getTables().containsKey(tablename)){ + if(ServerParse.DDL==sqlType){ + List dataNodes = new ArrayList<>(); + Map tables = schema.getTables(); + TableConfig tc=tables.get(tablename); + if (tables != null && (tc != null)) { + dataNodes = tc.getDataNodes(); + } + boolean isSlotFunction= tc.getRule() != null && tc.getRule().getRuleAlgorithm() instanceof SlotFunction; + Iterator iterator1 = dataNodes.iterator(); + int nodeSize = dataNodes.size(); + RouteResultsetNode[] nodes = new RouteResultsetNode[nodeSize]; + if(isSlotFunction){ + stmt=changeCreateTable(schema,tablename,stmt); + } + for(int i=0;i 0) { + tableName = tableName.substring(ind2 + 1); + } + return tableName; + } + + + /** + * 获取show语句table名字 + * + * @param stmt 执行语句 + * @param repPos 开始位置和位数 + * @return 表名 + * @author AStoneGod + */ + public static String getShowTableName(String stmt, int[] repPos) { + int startPos = repPos[0]; + int secInd = stmt.indexOf(' ', startPos + 1); + if (secInd < 0) { + secInd = stmt.length(); + } + + repPos[1] = secInd; + String tableName = stmt.substring(startPos, secInd).trim(); + + int ind2 = tableName.indexOf('.'); + if (ind2 > 0) { + tableName = tableName.substring(ind2 + 1); + } + return tableName; + } + + /** + * 获取语句中前关键字位置和占位个数表名位置 + * + * @param upStmt 执行语句 + * @param start 开始位置 + * @return int[] 关键字位置和占位个数 + * + * @author mycat + * + * @modification 修改支持语句中包含“IF NOT EXISTS”的情况 + * @date 2016/12/8 + * @modifiedBy Hash Zhang + */ + public static int[] getCreateTablePos(String upStmt, int start) { + String token1 = "CREATE "; + String token2 = " TABLE "; + String token3 = " EXISTS "; + int createInd = upStmt.indexOf(token1, start); + int tabInd1 = upStmt.indexOf(token2, start); + int tabInd2 = upStmt.indexOf(token3, tabInd1); + // 既包含CREATE又包含TABLE,且CREATE关键字在TABLE关键字之前 + if (createInd >= 0 && tabInd2 > 0 && tabInd2 > createInd) { + return new int[] { tabInd2, token3.length() }; + } else if(createInd >= 0 && tabInd1 > 0 && tabInd1 > createInd) { + return new int[] { tabInd1, token2.length() }; + } else { + return new int[] { -1, token2.length() };// 不满足条件时,只关注第一个返回值为-1,第二个任意 + } + } + + /** + * 获取语句中前关键字位置和占位个数表名位置 + * + * @param upStmt + * 执行语句 + * @param start + * 开始位置 + * @return int[]关键字位置和占位个数 + * @author aStoneGod + */ + public static int[] getCreateIndexPos(String upStmt, int start) { + String token1 = "CREATE "; + String token2 = " INDEX "; + String token3 = " ON "; + int createInd = upStmt.indexOf(token1, start); + int idxInd = upStmt.indexOf(token2, start); + int onInd = upStmt.indexOf(token3, start); + // 既包含CREATE又包含INDEX,且CREATE关键字在INDEX关键字之前, 且包含ON... + if (createInd >= 0 && idxInd > 0 && idxInd > createInd && onInd > 0 && onInd > idxInd) { + return new int[] {onInd , token3.length() }; + } else { + return new int[] { -1, token2.length() };// 不满足条件时,只关注第一个返回值为-1,第二个任意 + } + } + + /** + * 获取ALTER语句中前关键字位置和占位个数表名位置 + * + * @param upStmt 执行语句 + * @param start 开始位置 + * @return int[] 关键字位置和占位个数 + * @author aStoneGod + */ + public static int[] getAlterTablePos(String upStmt, int start) { + String token1 = "ALTER "; + String token2 = " TABLE "; + int createInd = upStmt.indexOf(token1, start); + int tabInd = upStmt.indexOf(token2, start); + // 既包含CREATE又包含TABLE,且CREATE关键字在TABLE关键字之前 + if (createInd >= 0 && tabInd > 0 && tabInd > createInd) { + return new int[] { tabInd, token2.length() }; + } else { + return new int[] { -1, token2.length() };// 不满足条件时,只关注第一个返回值为-1,第二个任意 + } + } + + /** + * 获取DROP语句中前关键字位置和占位个数表名位置 + * + * @param upStmt 执行语句 + * @param start 开始位置 + * @return int[] 关键字位置和占位个数 + * @author aStoneGod + */ + public static int[] getDropTablePos(String upStmt, int start) { + //增加 if exists判断 + if(upStmt.contains("EXISTS")){ + String token1 = "IF "; + String token2 = " EXISTS "; + int ifInd = upStmt.indexOf(token1, start); + int tabInd = upStmt.indexOf(token2, start); + if (ifInd >= 0 && tabInd > 0 && tabInd > ifInd) { + return new int[] { tabInd, token2.length() }; + } else { + return new int[] { -1, token2.length() };// 不满足条件时,只关注第一个返回值为-1,第二个任意 + } + }else { + String token1 = "DROP "; + String token2 = " TABLE "; + int createInd = upStmt.indexOf(token1, start); + int tabInd = upStmt.indexOf(token2, start); + + if (createInd >= 0 && tabInd > 0 && tabInd > createInd) { + return new int[] { tabInd, token2.length() }; + } else { + return new int[] { -1, token2.length() };// 不满足条件时,只关注第一个返回值为-1,第二个任意 + } + } + } + + + /** + * 获取DROP语句中前关键字位置和占位个数表名位置 + * + * @param upStmt + * 执行语句 + * @param start + * 开始位置 + * @return int[]关键字位置和占位个数 + * @author aStoneGod + */ + + public static int[] getDropIndexPos(String upStmt, int start) { + String token1 = "DROP "; + String token2 = " INDEX "; + String token3 = " ON "; + int createInd = upStmt.indexOf(token1, start); + int idxInd = upStmt.indexOf(token2, start); + int onInd = upStmt.indexOf(token3, start); + // 既包含CREATE又包含INDEX,且CREATE关键字在INDEX关键字之前, 且包含ON... + if (createInd >= 0 && idxInd > 0 && idxInd > createInd && onInd > 0 && onInd > idxInd) { + return new int[] {onInd , token3.length() }; + } else { + return new int[] { -1, token2.length() };// 不满足条件时,只关注第一个返回值为-1,第二个任意 + } + } + + /** + * 获取TRUNCATE语句中前关键字位置和占位个数表名位置 + * + * @param upStmt 执行语句 + * @param start 开始位置 + * @return int[] 关键字位置和占位个数 + * @author aStoneGod + */ + public static int[] getTruncateTablePos(String upStmt, int start) { + String token1 = "TRUNCATE "; + String token2 = " TABLE "; + int createInd = upStmt.indexOf(token1, start); + int tabInd = upStmt.indexOf(token2, start); + // 既包含CREATE又包含TABLE,且CREATE关键字在TABLE关键字之前 + if (createInd >= 0 && tabInd > 0 && tabInd > createInd) { + return new int[] { tabInd, token2.length() }; + } else { + return new int[] { -1, token2.length() };// 不满足条件时,只关注第一个返回值为-1,第二个任意 + } + } + + /** + * 获取语句中前关键字位置和占位个数表名位置 + * + * @param upStmt 执行语句 + * @param start 开始位置 + * @return int[] 关键字位置和占位个数 + * @author mycat + */ + public static int[] getSpecPos(String upStmt, int start) { + String token1 = " FROM "; + String token2 = " IN "; + int tabInd1 = upStmt.indexOf(token1, start); + int tabInd2 = upStmt.indexOf(token2, start); + if (tabInd1 > 0) { + if (tabInd2 < 0) { + return new int[] { tabInd1, token1.length() }; + } + return (tabInd1 < tabInd2) ? new int[] { tabInd1, token1.length() } + : new int[] { tabInd2, token2.length() }; + } else { + return new int[] { tabInd2, token2.length() }; + } + } + + /** + * 获取开始位置后的 LIKE、WHERE 位置 如果不含 LIKE、WHERE 则返回执行语句的长度 + * + * @param upStmt 执行sql + * @param start 开始位置 + * @return int + * @author mycat + */ + public static int getSpecEndPos(String upStmt, int start) { + int tabInd = upStmt.toUpperCase().indexOf(" LIKE ", start); + if (tabInd < 0) { + tabInd = upStmt.toUpperCase().indexOf(" WHERE ", start); + } + if (tabInd < 0) { + return upStmt.length(); + } + return tabInd; + } + + public static boolean processWithMycatSeq(SchemaConfig schema, int sqlType, + String origSQL, ServerConnection sc) { + // check if origSQL is with global sequence + // @micmiu it is just a simple judgement + //对应本地文件配置方式:insert into table1(id,name) values(next value for MYCATSEQ_GLOBAL,‘test’); + // edit by dingw,增加mycatseq_ 兼容,因为ServerConnection的373行,进行路由计算时,将原始语句全部转换为小写 + if (origSQL.indexOf(" MYCATSEQ_") != -1 || origSQL.indexOf("mycatseq_") != -1) { + processSQL(sc,schema,origSQL,sqlType); + return true; + } + return false; + } + + public static void processSQL(ServerConnection sc,SchemaConfig schema,String sql,int sqlType){ +// int sequenceHandlerType = MycatServer.getInstance().getConfig().getSystem().getSequnceHandlerType(); + final SessionSQLPair sessionSQLPair = new SessionSQLPair(sc.getSession2(), schema, sql, sqlType); +// modify by yanjunli 序列获取修改为多线程方式。使用分段锁方式,一个序列一把锁。 begin +// MycatServer.getInstance().getSequnceProcessor().addNewSql(sessionSQLPair); + MycatServer.getInstance().getSequenceExecutor().execute(new Runnable() { + @Override + public void run() { + MycatServer.getInstance().getSequnceProcessor().executeSeq(sessionSQLPair); + } + }); +// modify 序列获取修改为多线程方式。使用分段锁方式,一个序列一把锁。 end +// } + } + + public static boolean processInsert(SchemaConfig schema, int sqlType, + String origSQL, ServerConnection sc) throws SQLNonTransientException { + String tableName = StringUtil.getTableName(origSQL).toUpperCase(); + TableConfig tableConfig = schema.getTables().get(tableName); + boolean processedInsert=false; + //判断是有自增字段 + if (null != tableConfig && tableConfig.isAutoIncrement()) { + String primaryKey = tableConfig.getPrimaryKey(); + processedInsert=processInsert(sc,schema,sqlType,origSQL,tableName,primaryKey); + } + return processedInsert; + } + /* + * 找到返回主键的的位置 + * 找不到返回 -1 + * */ + private static int isPKInFields(String origSQL,String primaryKey,int firstLeftBracketIndex,int firstRightBracketIndex){ + + if (primaryKey == null) { + throw new RuntimeException("please make sure the primaryKey's config is not null in schemal.xml"); + } + + boolean isPrimaryKeyInFields = false; + int pkStart = 0; + String upperSQL = origSQL.substring(firstLeftBracketIndex, firstRightBracketIndex + 1).toUpperCase(); + for (int pkOffset = 0, primaryKeyLength = primaryKey.length();;) { + pkStart = upperSQL.indexOf(primaryKey, pkOffset); + if (pkStart >= 0 && pkStart < firstRightBracketIndex) { + char pkSide = upperSQL.charAt(pkStart - 1); + if (pkSide <= ' ' || pkSide == '`' || pkSide == ',' || pkSide == '(') { + pkSide = upperSQL.charAt(pkStart + primaryKey.length()); + isPrimaryKeyInFields = pkSide <= ' ' || pkSide == '`' || pkSide == ',' || pkSide == ')'; + } + if (isPrimaryKeyInFields) { + break; + } + pkOffset = pkStart + primaryKeyLength; + } else { + break; + } + } + if (isPrimaryKeyInFields) { + return firstLeftBracketIndex + pkStart; + } else { + return -1; + } + + } + + public static boolean processInsert(ServerConnection sc,SchemaConfig schema, + int sqlType,String origSQL,String tableName,String primaryKey) throws SQLNonTransientException { + + int firstLeftBracketIndex = origSQL.indexOf("("); + int firstRightBracketIndex = origSQL.indexOf(")"); + String upperSql = origSQL.toUpperCase(); + int valuesIndex = upperSql.indexOf("VALUES"); + int selectIndex = upperSql.indexOf("SELECT"); + int fromIndex = upperSql.indexOf("FROM"); + //屏蔽insert into table1 select * from table2语句 + if(firstLeftBracketIndex < 0) { + String msg = "invalid sql:" + origSQL; + LOGGER.warn(msg); + throw new SQLNonTransientException(msg); + } + //屏蔽批量插入 + if(selectIndex > 0 &&fromIndex>0&&selectIndex>firstRightBracketIndex&&valuesIndex<0) { + String msg = "multi insert not provided" ; + LOGGER.warn(msg); + throw new SQLNonTransientException(msg); + } + //插入语句必须提供列结构,因为MyCat默认对于表结构无感知 + if(valuesIndex + "VALUES".length() <= firstLeftBracketIndex) { + throw new SQLSyntaxErrorException("insert must provide ColumnList"); + } + List> vauleList = parseSqlValue(origSQL , valuesIndex); + //两种情况处理 1 有主键的 id ,但是值为null 进行改下 + // 2 没有主键的 需要插入 进行改写 + + //如果主键不在插入语句的fields中,则需要进一步处理 + boolean processedInsert= false; + int pkStart = isPKInFields(origSQL,primaryKey,firstLeftBracketIndex,firstRightBracketIndex); + + + if(pkStart == -1){ + processedInsert = true; + handleBatchInsert(sc, schema, sqlType,origSQL, valuesIndex, tableName, primaryKey, vauleList); + } else { + //判断 主键id的值是否为null + if(pkStart != -1) { + String subPrefix = origSQL.substring(0, pkStart); + char c; + int pkIndex = 0; + for(int index = 0, len = subPrefix.length(); index < len; index++) { + c = subPrefix.charAt(index); + if(c == ',') { + pkIndex ++; + } + } + processedInsert = handleBatchInsertWithPK(sc, schema, sqlType,origSQL, valuesIndex, tableName, primaryKey, vauleList , pkIndex); + } + } + return processedInsert; + } + + private static boolean handleBatchInsertWithPK(ServerConnection sc, SchemaConfig schema, int sqlType, + String origSQL, int valuesIndex, String tableName, String primaryKey, List> vauleList, + int pkIndex) { + boolean processedInsert = false; +// final String pk = "\\("+primaryKey+","; + final String mycatSeqPrefix = "next value for MYCATSEQ_"+tableName.toUpperCase() ; + + /*"VALUES".length() ==6 */ + String prefix = origSQL.substring(0, valuesIndex + 6); +// + + StringBuilder sb = new StringBuilder(""); + for(List list : vauleList) { + sb.append("("); + String pkValue = list.get(pkIndex).trim().toLowerCase(); + //null值替换为 next value for MYCATSEQ_tableName + if("null".equals(pkValue.trim())) { + list.set(pkIndex, mycatSeqPrefix); + processedInsert = true; + } + for(String val : list) { + sb.append(val).append(","); + } + sb.setCharAt(sb.length() - 1, ')'); + sb.append(","); + } + sb.setCharAt(sb.length() - 1, ' ');; + if(processedInsert) { + processSQL(sc, schema,prefix+sb.toString(), sqlType); + + } + return processedInsert; + } + + public static List handleBatchInsert(String origSQL, int valuesIndex) { + List handledSQLs = new LinkedList<>(); + String prefix = origSQL.substring(0, valuesIndex + "VALUES".length()); + String values = origSQL.substring(valuesIndex + "VALUES".length()); + int flag = 0; + StringBuilder currentValue = new StringBuilder(); + currentValue.append(prefix); + for (int i = 0; i < values.length(); i++) { + char j = values.charAt(i); + if (j == '(' && flag == 0) { + flag = 1; + currentValue.append(j); + } else if (j == '\"' && flag == 1) { + flag = 2; + currentValue.append(j); + } else if (j == '\'' && flag == 1) { + flag = 3; + currentValue.append(j); + } else if (j == '\\' && flag == 2) { + flag = 4; + currentValue.append(j); + } else if (j == '\\' && flag == 3) { + flag = 5; + currentValue.append(j); + } else if (flag == 4) { + flag = 2; + currentValue.append(j); + } else if (flag == 5) { + flag = 3; + currentValue.append(j); + } else if (j == '\"' && flag == 2) { + flag = 1; + currentValue.append(j); + } else if (j == '\'' && flag == 3) { + flag = 1; + currentValue.append(j); + } else if (j == ')' && flag == 1) { + flag = 0; + currentValue.append(j); + handledSQLs.add(currentValue.toString()); + currentValue = new StringBuilder(); + currentValue.append(prefix); + } else if (j == ',' && flag == 0) { + continue; + } else { + currentValue.append(j); + } + } + return handledSQLs; + } + /** + * 对于插入的sql : "insert into hotnews(title,name) values('test1',\"name\"),('(test)',\"(test)\"),('\\\"',\"\\'\"),(\")\",\"\\\"\\')\")": + * 需要返回结果: + *[[ 'test1', "name"], + * ['(test)', "(test)"], + * ['\"', "\'"], + * [")", "\"\')"], + * [ 1, null] + * 值结果的解析 + */ + public static List> parseSqlValue(String origSQL,int valuesIndex ) { + List> valueArray = new ArrayList<>(); + String valueStr = origSQL.substring(valuesIndex + 6);// 6 values 长度为6 + String preStr = origSQL.substring(0, valuesIndex );// 6 values 长度为6 + int pos = 0 ; + int flag = -1; + int len = valueStr.length(); + StringBuilder currentValue = new StringBuilder(); +// int colNum = 2; // + char c ; + List curList = new ArrayList<>(); + for( ;pos < len; pos ++) { + c = valueStr.charAt(pos); + if(flag == 1 || flag == 2) { + currentValue.append(c); + if(c == '\\') { + char nextCode = valueStr.charAt(pos + 1); + if(nextCode == '\'' || nextCode == '\"') { + currentValue.append(nextCode); + pos++; + continue; + } + } + if(c == '\"' && flag == 1) { + flag = 0; + continue; + } + if(c == '\'' && flag == 2) { + flag = 0; + continue; + } + } else if(c == '\"'){ + currentValue.append(c); + flag = 1; + } else if (c == '\'') { + currentValue.append(c); + flag = 2; + } else if (c == '(') { + curList = new ArrayList<>(); + flag = 0; + } else if(flag == 4 ) { + if(c == ',') { + flag = 0; + continue; + } + } else if(c == ',') { +// System.out.println(currentValue); + curList.add(currentValue.toString()); + currentValue.delete(0, currentValue.length()); + } else if(c == ')'){ + flag = 4; +// System.out.println(currentValue); + curList.add(currentValue.toString()); + currentValue.delete(0, currentValue.length()); + valueArray.add(curList); + } else { + currentValue.append(c); + } + } + return valueArray; + } + + /** + * 对于主键不在插入语句的fields中的SQL,需要改写。比如hotnews主键为id,插入语句为: + * insert into hotnews(title) values('aaa'); + * 需要改写成: + * insert into hotnews(id, title) values(next value for MYCATSEQ_hotnews,'aaa'); + */ + public static void handleBatchInsert(ServerConnection sc, SchemaConfig schema, + int sqlType,String origSQL, int valuesIndex,String tableName, String primaryKey , List> vauleList) { + + final String pk = "\\("+primaryKey+","; + final String mycatSeqPrefix = "(next value for MYCATSEQ_"+tableName.toUpperCase()+""; + + /*"VALUES".length() ==6 */ + String prefix = origSQL.substring(0, valuesIndex + 6); +// + prefix = prefix.replaceFirst("\\(", pk); + + StringBuilder sb = new StringBuilder(""); + for(List list : vauleList) { + sb.append(mycatSeqPrefix); + for(String val : list) { + sb.append(",").append(val); + } + sb.append("),"); + } + sb.setCharAt(sb.length() - 1, ' ');; + processSQL(sc, schema,prefix+sb.toString(), sqlType); + } +// /** +// * 对于主键不在插入语句的fields中的SQL,需要改写。比如hotnews主键为id,插入语句为: +// * insert into hotnews(title) values('aaa'); +// * 需要改写成: +// * insert into hotnews(id, title) values(next value for MYCATSEQ_hotnews,'aaa'); +// */ +// public static void handleBatchInsert(ServerConnection sc, SchemaConfig schema, +// int sqlType,String origSQL, int valuesIndex,String tableName, String primaryKey) { +// +// final String pk = "\\("+primaryKey+","; +// final String mycatSeqPrefix = "(next value for MYCATSEQ_"+tableName.toUpperCase()+","; +// +// /*"VALUES".length() ==6 */ +// String prefix = origSQL.substring(0, valuesIndex + 6); +// String values = origSQL.substring(valuesIndex + 6); +// +// prefix = prefix.replaceFirst("\\(", pk); +// values = values.replaceFirst("\\(", mycatSeqPrefix); +// values =Pattern.compile(",\\s*\\(").matcher(values).replaceAll(","+mycatSeqPrefix); +// processSQL(sc, schema,prefix+values, sqlType); +// } + + + public static RouteResultset routeToMultiNode(boolean cache,RouteResultset rrs, Collection dataNodes, String stmt) { + RouteResultsetNode[] nodes = new RouteResultsetNode[dataNodes.size()]; + int i = 0; + RouteResultsetNode node; + for (String dataNode : dataNodes) { + node = new RouteResultsetNode(dataNode, rrs.getSqlType(), stmt); + node.setSource(rrs); + if(rrs.getDataNodeSlotMap().containsKey(dataNode)){ + node.setSlot(rrs.getDataNodeSlotMap().get(dataNode)); + } + if (rrs.getCanRunInReadDB() != null) { + node.setCanRunInReadDB(rrs.getCanRunInReadDB()); + } + if(rrs.getRunOnSlave() != null){ + nodes[0].setRunOnSlave(rrs.getRunOnSlave()); + } + nodes[i++] = node; + } + rrs.setCacheAble(cache); + rrs.setNodes(nodes); + return rrs; + } + + public static RouteResultset routeToMultiNode(boolean cache, RouteResultset rrs, Collection dataNodes, + String stmt, boolean isGlobalTable) { + + rrs = routeToMultiNode(cache, rrs, dataNodes, stmt); + rrs.setGlobalTable(isGlobalTable); + return rrs; + } + + public static void routeForTableMeta(RouteResultset rrs, + SchemaConfig schema, String tableName, String sql) { + String dataNode = null; + if (isNoSharding(schema,tableName)) {//不分库的直接从schema中获取dataNode + dataNode = schema.getDataNode(); + } else { + dataNode = getMetaReadDataNode(schema, tableName); + } + + RouteResultsetNode[] nodes = new RouteResultsetNode[1]; + nodes[0] = new RouteResultsetNode(dataNode, rrs.getSqlType(), sql); + nodes[0].setSource(rrs); + if(rrs.getDataNodeSlotMap().containsKey(dataNode)){ + nodes[0].setSlot(rrs.getDataNodeSlotMap().get(dataNode)); + } + if (rrs.getCanRunInReadDB() != null) { + nodes[0].setCanRunInReadDB(rrs.getCanRunInReadDB()); + } + if(rrs.getRunOnSlave() != null){ + nodes[0].setRunOnSlave(rrs.getRunOnSlave()); + } + rrs.setNodes(nodes); + } + + /** + * 根据表名随机获取一个节点 + * + * @param schema 数据库名 + * @param table 表名 + * @return 数据节点 + * @author mycat + */ + private static String getMetaReadDataNode(SchemaConfig schema, + String table) { + // Table名字被转化为大写的,存储在schema + table = table.toUpperCase(); + String dataNode = null; + Map tables = schema.getTables(); + TableConfig tc; + if (tables != null && (tc = tables.get(table)) != null) { + dataNode = getAliveRandomDataNode(tc); + } + return dataNode; + } + + /** + * 解决getRandomDataNode方法获取错误节点的问题. + * @param tc + * @return + */ + private static String getAliveRandomDataNode(TableConfig tc) { + List randomDns = (List)tc.getDataNodes().clone(); + + MycatConfig mycatConfig = MycatServer.getInstance().getConfig(); + if (mycatConfig != null) { + Collections.shuffle(randomDns); + for (String randomDn : randomDns) { + PhysicalDBNode physicalDBNode = mycatConfig.getDataNodes().get(randomDn); + if (physicalDBNode != null) { + if (physicalDBNode.getDbPool().getSource().isAlive()) { + for (PhysicalDBPool pool : MycatServer.getInstance().getConfig().getDataHosts().values()) { + PhysicalDatasource source = pool.getSource(); + if (source.getHostConfig().containDataNode(randomDn) && pool.getSource().isAlive()) { + return randomDn; + } + } + } + } + } + } + + // all fail return default + return tc.getRandomDataNode(); + } + + @Deprecated + private static String getRandomDataNode(TableConfig tc) { + //写节点不可用,意味着读节点也不可用。 + //直接使用下一个 dataHost + String randomDn = tc.getRandomDataNode(); + MycatConfig mycatConfig = MycatServer.getInstance().getConfig(); + if (mycatConfig != null) { + PhysicalDBNode physicalDBNode = mycatConfig.getDataNodes().get(randomDn); + if (physicalDBNode != null) { + if (physicalDBNode.getDbPool().getSource().isAlive()) { + for (PhysicalDBPool pool : MycatServer.getInstance() + .getConfig() + .getDataHosts() + .values()) { + if (pool.getSource().getHostConfig().containDataNode(randomDn)) { + continue; + } + + if (pool.getSource().isAlive()) { + return pool.getSource().getHostConfig().getRandomDataNode(); + } + } + } + } + } + + //all fail return default + return randomDn; + } + + /** + * 根据 ER分片规则获取路由集合 + * + * @param stmt 执行的语句 + * @param rrs 数据路由集合 + * @param tc 表实体 + * @param joinKeyVal 连接属性 + * @return RouteResultset(数据路由集合) * + * @throws SQLNonTransientException,IllegalShardingColumnValueException + * @author mycat + */ + + public static RouteResultset routeByERParentKey(ServerConnection sc,SchemaConfig schema, + int sqlType,String stmt, + RouteResultset rrs, TableConfig tc, String joinKeyVal) + throws SQLNonTransientException { + + // only has one parent level and ER parent key is parent + // table's partition key + if (tc.isSecondLevel() + //判断是否为二级子表(父表不再有父表) + && tc.getParentTC().getPartitionColumn() + .equals(tc.getParentKey())) { // using + // parent + // rule to + // find + // datanode + Set parentColVal = new HashSet(1); + ColumnRoutePair pair = new ColumnRoutePair(joinKeyVal); + parentColVal.add(pair); + Set dataNodeSet = ruleCalculate(tc.getParentTC(), parentColVal,rrs.getDataNodeSlotMap()); + if (dataNodeSet.isEmpty() || dataNodeSet.size() > 1) { + throw new SQLNonTransientException( + "parent key can't find valid datanode ,expect 1 but found: " + + dataNodeSet.size()); + } + String dn = dataNodeSet.iterator().next(); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("found partion node (using parent partion rule directly) for child table to insert " + + dn + " sql :" + stmt); + } + return RouterUtil.routeToSingleNode(rrs, dn, stmt); + } + return null; + } + + /** + * @return dataNodeIndex -> [partitionKeysValueTuple+] + */ + public static Set ruleByJoinValueCalculate(RouteResultset rrs, TableConfig tc, + Set colRoutePairSet) throws SQLNonTransientException { + + String joinValue = ""; + + if(colRoutePairSet.size() > 1) { + LOGGER.warn("joinKey can't have multi Value"); + } else { + Iterator it = colRoutePairSet.iterator(); + ColumnRoutePair joinCol = it.next(); + joinValue = joinCol.colValue; + } + + Set retNodeSet = new LinkedHashSet(); + + Set nodeSet; + if (tc.isSecondLevel() + && tc.getParentTC().getPartitionColumn() + .equals(tc.getParentKey())) { // using + // parent + // rule to + // find + // datanode + + nodeSet = ruleCalculate(tc.getParentTC(),colRoutePairSet,rrs.getDataNodeSlotMap()); + if (nodeSet.isEmpty()) { + throw new SQLNonTransientException( + "parent key can't find valid datanode ,expect 1 but found: " + + nodeSet.size()); + } + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("found partion node (using parent partion rule directly) for child table to insert " + + nodeSet + " sql :" + rrs.getStatement()); + } + retNodeSet.addAll(nodeSet); + +// for(ColumnRoutePair pair : colRoutePairSet) { +// nodeSet = ruleCalculate(tc.getParentTC(),colRoutePairSet); +// if (nodeSet.isEmpty() || nodeSet.size() > 1) {//an exception would be thrown, if sql was executed on more than on sharding +// throw new SQLNonTransientException( +// "parent key can't find valid datanode ,expect 1 but found: " +// + nodeSet.size()); +// } +// String dn = nodeSet.iterator().next(); +// if (LOGGER.isDebugEnabled()) { +// LOGGER.debug("found partion node (using parent partion rule directly) for child table to insert " +// + dn + " sql :" + rrs.getStatement()); +// } +// retNodeSet.addAll(nodeSet); +// } + return retNodeSet; + } else { + retNodeSet.addAll(tc.getParentTC().getDataNodes()); + } + + return retNodeSet; + } + + + /** + * @return dataNodeIndex -> [partitionKeysValueTuple+] + */ + public static Set ruleCalculate(TableConfig tc, + Set colRoutePairSet,Map dataNodeSlotMap) { + Set routeNodeSet = new LinkedHashSet(); + String col = tc.getRule().getColumn(); + RuleConfig rule = tc.getRule(); + AbstractPartitionAlgorithm algorithm = rule.getRuleAlgorithm(); + for (ColumnRoutePair colPair : colRoutePairSet) { + if (colPair.colValue != null) { + Integer nodeIndx = algorithm.calculate(colPair.colValue); + if (nodeIndx == null) { + throw new IllegalArgumentException( + "can't find datanode for sharding column:" + col + + " val:" + colPair.colValue); + } else { + String dataNode = tc.getDataNodes().get(nodeIndx); + routeNodeSet.add(dataNode); + if(algorithm instanceof SlotFunction) { + dataNodeSlotMap.put(dataNode,((SlotFunction) algorithm).slotValue()); + } + colPair.setNodeId(nodeIndx); + } + } else if (colPair.rangeValue != null) { + Integer[] nodeRange = algorithm.calculateRange( + String.valueOf(colPair.rangeValue.beginValue), + String.valueOf(colPair.rangeValue.endValue)); + if (nodeRange != null) { + /** + * 不能确认 colPair的 nodeid是否会有其它影响 + */ + if (nodeRange.length == 0) { + routeNodeSet.addAll(tc.getDataNodes()); + } else { + ArrayList dataNodes = tc.getDataNodes(); + String dataNode = null; + for (Integer nodeId : nodeRange) { + dataNode = dataNodes.get(nodeId); + if(algorithm instanceof SlotFunction) { + dataNodeSlotMap.put(dataNode,((SlotFunction) algorithm).slotValue()); + } + routeNodeSet.add(dataNode); + } + } + } + } + + } + return routeNodeSet; + } + + /** + * 多表路由 + */ + public static RouteResultset tryRouteForTables(SchemaConfig schema, DruidShardingParseInfo ctx, + RouteCalculateUnit routeUnit, RouteResultset rrs, boolean isSelect, LayerCachePool cachePool) + throws SQLNonTransientException { + + List tables = ctx.getTables(); + + if(schema.isNoSharding()||(tables.size() >= 1&&isNoSharding(schema,tables.get(0)))) { + return routeToSingleNode(rrs, schema.getDataNode(), ctx.getSql()); + } + + //只有一个表的 + if(tables.size() == 1) { + return RouterUtil.tryRouteForOneTable(schema, ctx, routeUnit, tables.get(0), rrs, isSelect, cachePool); + } + + Set retNodesSet = new HashSet(); + //每个表对应的路由映射 + Map> tablesRouteMap = new HashMap>(); + + //分库解析信息不为空 + Map>> tablesAndConditions = routeUnit.getTablesAndConditions(); + if(tablesAndConditions != null && tablesAndConditions.size() > 0) { + //为分库表找路由 + RouterUtil.findRouteWithcConditionsForTables(schema, rrs, tablesAndConditions, tablesRouteMap, ctx.getSql(), cachePool, isSelect); + if(rrs.isFinishedRoute()) { + return rrs; + } + } + + //为全局表和单库表找路由 + for(String tableName : tables) { + + TableConfig tableConfig = schema.getTables().get(tableName.toUpperCase()); + + if(tableConfig == null) { + //add 如果表读取不到则先将表名从别名中读取转化后再读取 + String alias = ctx.getTableAliasMap().get(tableName); + if(!StringUtil.isEmpty(alias)){ + tableConfig = schema.getTables().get(alias.toUpperCase()); + } + + if(tableConfig == null){ + String msg = "can't find table define in schema "+ tableName + " schema:" + schema.getName(); + LOGGER.warn(msg); + throw new SQLNonTransientException(msg); + } + + } + if(tableConfig.isGlobalTable()) {//全局表 + if(tablesRouteMap.get(tableName) == null) { + tablesRouteMap.put(tableName, new HashSet()); + } + tablesRouteMap.get(tableName).addAll(tableConfig.getDataNodes()); + } else if(tablesRouteMap.get(tableName) == null) { //余下的表都是单库表 + tablesRouteMap.put(tableName, new HashSet()); + tablesRouteMap.get(tableName).addAll(tableConfig.getDataNodes()); + } + } + + boolean isFirstAdd = true; + for(Map.Entry> entry : tablesRouteMap.entrySet()) { + if(entry.getValue() == null || entry.getValue().size() == 0) { + throw new SQLNonTransientException("parent key can't find any valid datanode "); + } else { + if(isFirstAdd) { + retNodesSet.addAll(entry.getValue()); + isFirstAdd = false; + } else { + retNodesSet.retainAll(entry.getValue()); + if(retNodesSet.size() == 0) {//两个表的路由无交集 + String errMsg = "invalid route in sql, multi tables found but datanode has no intersection " + + " sql:" + ctx.getSql(); + LOGGER.warn(errMsg); + throw new SQLNonTransientException(errMsg); + } + } + } + } + + if(retNodesSet != null && retNodesSet.size() > 0) { + String tableName = tables.get(0); + TableConfig tableConfig = schema.getTables().get(tableName.toUpperCase()); + if(tableConfig.isDistTable()){ + routeToDistTableNode(tableName,schema, rrs, ctx.getSql(), tablesAndConditions, cachePool, isSelect); + return rrs; + } + + if(retNodesSet.size() > 1 && isAllGlobalTable(ctx, schema)) { + // mulit routes ,not cache route result + if (isSelect) { + rrs.setCacheAble(false); + ArrayList retNodeList = new ArrayList(retNodesSet); + Collections.shuffle(retNodeList);//by kaiz : add shuffle + routeToSingleNode(rrs, retNodeList.get(0), ctx.getSql()); + } + else {//delete 删除全局表的记录 + routeToMultiNode(isSelect, rrs, retNodesSet, ctx.getSql(),true); + } + + } else { + routeToMultiNode(isSelect, rrs, retNodesSet, ctx.getSql()); + } + + } + return rrs; + + } + + + /** + * + * 单表路由 + */ + public static RouteResultset tryRouteForOneTable(SchemaConfig schema, DruidShardingParseInfo ctx, + RouteCalculateUnit routeUnit, String tableName, RouteResultset rrs, boolean isSelect, + LayerCachePool cachePool) throws SQLNonTransientException { + + if (isNoSharding(schema, tableName)) { + return routeToSingleNode(rrs, schema.getDataNode(), ctx.getSql()); + } + + TableConfig tc = schema.getTables().get(tableName); + if(tc == null) { + String msg = "can't find table define in schema " + tableName + " schema:" + schema.getName(); + LOGGER.warn(msg); + throw new SQLNonTransientException(msg); + } + + if(tc.isDistTable()){ + return routeToDistTableNode(tableName,schema,rrs,ctx.getSql(), routeUnit.getTablesAndConditions(), cachePool,isSelect); + } + + if(tc.isGlobalTable()) {//全局表 + if(isSelect) { + // global select ,not cache route result + rrs.setCacheAble(false); + return routeToSingleNode(rrs, getAliveRandomDataNode(tc)/*getRandomDataNode(tc)*/, ctx.getSql()); + } else {//insert into 全局表的记录 + return routeToMultiNode(false, rrs, tc.getDataNodes(), ctx.getSql(),true); + } + } else {//单表或者分库表 + if (!checkRuleRequired(schema, ctx, routeUnit, tc)) { + throw new IllegalArgumentException("route rule for table " + + tc.getName() + " is required: " + ctx.getSql()); + + } + if(tc.getPartitionColumn() == null && !tc.isSecondLevel()) {//单表且不是childTable +// return RouterUtil.routeToSingleNode(rrs, tc.getDataNodes().get(0),ctx.getSql()); + return routeToMultiNode(rrs.isCacheAble(), rrs, tc.getDataNodes(), ctx.getSql()); + } else { + //每个表对应的路由映射 + Map> tablesRouteMap = new HashMap>(); + if(routeUnit.getTablesAndConditions() != null && routeUnit.getTablesAndConditions().size() > 0) { + RouterUtil.findRouteWithcConditionsForTables(schema, rrs, routeUnit.getTablesAndConditions(), tablesRouteMap, ctx.getSql(), cachePool, isSelect); + if(rrs.isFinishedRoute()) { + return rrs; + } + } + + if(tablesRouteMap.get(tableName) == null) { + return routeToMultiNode(rrs.isCacheAble(), rrs, tc.getDataNodes(), ctx.getSql()); + } else { + return routeToMultiNode(rrs.isCacheAble(), rrs, tablesRouteMap.get(tableName), ctx.getSql()); + } + } + } + } + + private static RouteResultset routeToDistTableNode(String tableName, SchemaConfig schema, RouteResultset rrs, + String orgSql, Map>> tablesAndConditions, + LayerCachePool cachePool, boolean isSelect) throws SQLNonTransientException { + + TableConfig tableConfig = schema.getTables().get(tableName); + if(tableConfig == null) { + String msg = "can't find table define in schema " + tableName + " schema:" + schema.getName(); + LOGGER.warn(msg); + throw new SQLNonTransientException(msg); + } + if(tableConfig.isGlobalTable()){ + String msg = "can't suport district table " + tableName + " schema:" + schema.getName() + " for global table "; + LOGGER.warn(msg); + throw new SQLNonTransientException(msg); + } + String partionCol = tableConfig.getPartitionColumn(); +// String primaryKey = tableConfig.getPrimaryKey(); + boolean isLoadData=false; + + Set tablesRouteSet = new HashSet(); + + List dataNodes = tableConfig.getDataNodes(); + if(dataNodes.size()>1){ + String msg = "can't suport district table " + tableName + " schema:" + schema.getName() + " for mutiple dataNode " + dataNodes; + LOGGER.warn(msg); + throw new SQLNonTransientException(msg); + } + String dataNode = dataNodes.get(0); + + //主键查找缓存暂时不实现 + if(tablesAndConditions.isEmpty()){ + List subTables = tableConfig.getDistTables(); + tablesRouteSet.addAll(subTables); + } + + for(Map.Entry>> entry : tablesAndConditions.entrySet()) { + boolean isFoundPartitionValue = partionCol != null && entry.getValue().get(partionCol) != null; + Map> columnsMap = entry.getValue(); + + Set partitionValue = columnsMap.get(partionCol); + if(partitionValue == null || partitionValue.size() == 0) { + tablesRouteSet.addAll(tableConfig.getDistTables()); + } else { + for(ColumnRoutePair pair : partitionValue) { + AbstractPartitionAlgorithm algorithm = tableConfig.getRule().getRuleAlgorithm(); + if(pair.colValue != null) { + Integer tableIndex = algorithm.calculate(pair.colValue); + if(tableIndex == null) { + String msg = "can't find any valid datanode :" + tableConfig.getName() + + " -> " + tableConfig.getPartitionColumn() + " -> " + pair.colValue; + LOGGER.warn(msg); + throw new SQLNonTransientException(msg); + } + String subTable = tableConfig.getDistTables().get(tableIndex); + if(subTable != null) { + tablesRouteSet.add(subTable); + if(algorithm instanceof SlotFunction){ + rrs.getDataNodeSlotMap().put(subTable,((SlotFunction) algorithm).slotValue()); + } + } + } + if(pair.rangeValue != null) { + Integer[] tableIndexs = algorithm + .calculateRange(pair.rangeValue.beginValue.toString(), pair.rangeValue.endValue.toString()); + for(Integer idx : tableIndexs) { + String subTable = tableConfig.getDistTables().get(idx); + if(subTable != null) { + tablesRouteSet.add(subTable); + if(algorithm instanceof SlotFunction){ + rrs.getDataNodeSlotMap().put(subTable,((SlotFunction) algorithm).slotValue()); + } + } + } + } + } + } + } + + Object[] subTables = tablesRouteSet.toArray(); + RouteResultsetNode[] nodes = new RouteResultsetNode[subTables.length]; + Map dataNodeSlotMap= rrs.getDataNodeSlotMap(); + for(int i=0;i>> tablesAndConditions, + Map> tablesRouteMap, String sql, LayerCachePool cachePool, boolean isSelect) + throws SQLNonTransientException { + + //为分库表找路由 + for(Map.Entry>> entry : tablesAndConditions.entrySet()) { + String tableName = entry.getKey().toUpperCase(); + TableConfig tableConfig = schema.getTables().get(tableName); + if(tableConfig == null) { + String msg = "can't find table define in schema " + + tableName + " schema:" + schema.getName(); + LOGGER.warn(msg); + throw new SQLNonTransientException(msg); + } + if(tableConfig.getDistTables()!=null && tableConfig.getDistTables().size()>0){ + routeToDistTableNode(tableName,schema,rrs,sql, tablesAndConditions, cachePool,isSelect); + } + //全局表或者不分库的表略过(全局表后面再计算) + if(tableConfig.isGlobalTable() || schema.getTables().get(tableName).getDataNodes().size() == 1) { + continue; + } else {//非全局表:分库表、childTable、其他 + Map> columnsMap = entry.getValue(); + String joinKey = tableConfig.getJoinKey(); + String partionCol = tableConfig.getPartitionColumn(); + String primaryKey = tableConfig.getPrimaryKey(); + boolean isFoundPartitionValue = partionCol != null && entry.getValue().get(partionCol) != null; + boolean isLoadData=false; + if (LOGGER.isDebugEnabled() + && sql.startsWith(LoadData.loadDataHint)||rrs.isLoadData()) { + //由于load data一次会计算很多路由数据,如果输出此日志会极大降低load data的性能 + isLoadData=true; + } + if(entry.getValue().get(primaryKey) != null && entry.getValue().size() == 1&&!isLoadData) + {//主键查找 + // try by primary key if found in cache + Set primaryKeyPairs = entry.getValue().get(primaryKey); + if (primaryKeyPairs != null) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("try to find cache by primary key "); + } + String tableKey = schema.getName() + '_' + tableName; + boolean allFound = true; + for (ColumnRoutePair pair : primaryKeyPairs) {//可能id in(1,2,3)多主键 + String cacheKey = pair.colValue; + String dataNode = (String) cachePool.get(tableKey, cacheKey); + if (dataNode == null) { + allFound = false; + continue; + } else { + if(tablesRouteMap.get(tableName) == null) { + tablesRouteMap.put(tableName, new HashSet()); + } + tablesRouteMap.get(tableName).add(dataNode); + continue; + } + } + if (!allFound) { + // need cache primary key ->datanode relation + if (isSelect && tableConfig.getPrimaryKey() != null) { + rrs.setPrimaryKey(tableKey + '.' + tableConfig.getPrimaryKey()); + } + } else {//主键缓存中找到了就执行循环的下一轮 + continue; + } + } + } + if (isFoundPartitionValue) {//分库表 + Set partitionValue = columnsMap.get(partionCol); + if(partitionValue == null || partitionValue.size() == 0) { + if(tablesRouteMap.get(tableName) == null) { + tablesRouteMap.put(tableName, new HashSet()); + } + tablesRouteMap.get(tableName).addAll(tableConfig.getDataNodes()); + } else { + for(ColumnRoutePair pair : partitionValue) { + AbstractPartitionAlgorithm algorithm = tableConfig.getRule().getRuleAlgorithm(); + if(pair.colValue != null) { + Integer nodeIndex = algorithm.calculate(pair.colValue); + if(nodeIndex == null) { + String msg = "can't find any valid datanode :" + tableConfig.getName() + + " -> " + tableConfig.getPartitionColumn() + " -> " + pair.colValue; + LOGGER.warn(msg); + throw new SQLNonTransientException(msg); + } + + ArrayList dataNodes = tableConfig.getDataNodes(); + String node; + if (nodeIndex >=0 && nodeIndex < dataNodes.size()) { + node = dataNodes.get(nodeIndex); + + } else { + node = null; + String msg = "Can't find a valid data node for specified node index :" + + tableConfig.getName() + " -> " + tableConfig.getPartitionColumn() + + " -> " + pair.colValue + " -> " + "Index : " + nodeIndex; + LOGGER.warn(msg); + throw new SQLNonTransientException(msg); + } + if(node != null) { + if(tablesRouteMap.get(tableName) == null) { + tablesRouteMap.put(tableName, new HashSet()); + } + if(algorithm instanceof SlotFunction){ + rrs.getDataNodeSlotMap().put(node,((SlotFunction) algorithm).slotValue()); + } + tablesRouteMap.get(tableName).add(node); + } + } + if(pair.rangeValue != null) { + Integer[] nodeIndexs = algorithm + .calculateRange(pair.rangeValue.beginValue.toString(), pair.rangeValue.endValue.toString()); + ArrayList dataNodes = tableConfig.getDataNodes(); + String node; + for(Integer idx : nodeIndexs) { + if (idx >= 0 && idx < dataNodes.size()) { + node = dataNodes.get(idx); + } else { + String msg = "Can't find valid data node(s) for some of specified node indexes :" + + tableConfig.getName() + " -> " + tableConfig.getPartitionColumn(); + LOGGER.warn(msg); + throw new SQLNonTransientException(msg); + } + if(node != null) { + if(tablesRouteMap.get(tableName) == null) { + tablesRouteMap.put(tableName, new HashSet()); + } + if(algorithm instanceof SlotFunction){ + rrs.getDataNodeSlotMap().put(node,((SlotFunction) algorithm).slotValue()); + } + tablesRouteMap.get(tableName).add(node); + + } + } + } + } + } + } else if(joinKey != null && columnsMap.get(joinKey) != null && columnsMap.get(joinKey).size() != 0) {//childTable (如果是select 语句的父子表join)之前要找到root table,将childTable移除,只留下root table + Set joinKeyValue = columnsMap.get(joinKey); + + Set dataNodeSet = ruleByJoinValueCalculate(rrs, tableConfig, joinKeyValue); + + if (dataNodeSet.isEmpty()) { + throw new SQLNonTransientException( + "parent key can't find any valid datanode "); + } + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("found partion nodes (using parent partion rule directly) for child table to update " + + Arrays.toString(dataNodeSet.toArray()) + " sql :" + sql); + } + if (dataNodeSet.size() > 1) { + routeToMultiNode(rrs.isCacheAble(), rrs, dataNodeSet, sql); + rrs.setFinishedRoute(true); + return; + } else { + rrs.setCacheAble(true); + routeToSingleNode(rrs, dataNodeSet.iterator().next(), sql); + return; + } + + } else { + //没找到拆分字段,该表的所有节点都路由 + if(tablesRouteMap.get(tableName) == null) { + tablesRouteMap.put(tableName, new HashSet()); + } + boolean isSlotFunction= tableConfig.getRule() != null && tableConfig.getRule().getRuleAlgorithm() instanceof SlotFunction; + if(isSlotFunction){ + for (String dn : tableConfig.getDataNodes()) { + rrs.getDataNodeSlotMap().put(dn,-1); + } + } + tablesRouteMap.get(tableName).addAll(tableConfig.getDataNodes()); + } + } + } + } + + public static boolean isAllGlobalTable(DruidShardingParseInfo ctx, SchemaConfig schema) { + boolean isAllGlobal = false; + for(String table : ctx.getTables()) { + TableConfig tableConfig = schema.getTables().get(table); + if(tableConfig!=null && tableConfig.isGlobalTable()) { + isAllGlobal = true; + } else { + return false; + } + } + return isAllGlobal; + } + + /** + * + * @param schema + * @param ctx + * @param tc + * @return true表示校验通过,false表示检验不通过 + */ + public static boolean checkRuleRequired(SchemaConfig schema, DruidShardingParseInfo ctx, RouteCalculateUnit routeUnit, TableConfig tc) { + if(!tc.isRuleRequired()) { + return true; + } + boolean hasRequiredValue = false; + String tableName = tc.getName(); + if(routeUnit.getTablesAndConditions().get(tableName) == null || routeUnit.getTablesAndConditions().get(tableName).size() == 0) { + hasRequiredValue = false; + } else { + for(Map.Entry> condition : routeUnit.getTablesAndConditions().get(tableName).entrySet()) { + + String colName = condition.getKey(); + //条件字段是拆分字段 + if(colName.equals(tc.getPartitionColumn())) { + hasRequiredValue = true; + break; + } + } + } + return hasRequiredValue; + } + + + /** + * 增加判断支持未配置分片的表走默认的dataNode + * @param schemaConfig + * @param tableName + * @return + */ + public static boolean isNoSharding(SchemaConfig schemaConfig, String tableName) { + // Table名字被转化为大写的,存储在schema + tableName = tableName.toUpperCase(); + if (schemaConfig.isNoSharding()) { + return true; + } + + if (schemaConfig.getDataNode() != null && !schemaConfig.getTables().containsKey(tableName)) { + return true; + } + + return false; + } + + /** + * 系统表判断,某些sql语句会查询系统表或者跟系统表关联 + * @author lian + * @date 2016年12月2日 + * @param tableName + * @return + */ + public static boolean isSystemSchema(String tableName) { + // 以information_schema, mysql开头的是系统表 + if (tableName.startsWith("INFORMATION_SCHEMA.") + || tableName.startsWith("MYSQL.") + || tableName.startsWith("PERFORMANCE_SCHEMA.")) { + return true; + } + + return false; + } + + /** + * 判断条件是否永真 + * @param expr + * @return + */ + public static boolean isConditionAlwaysTrue(SQLExpr expr) { + Object o = WallVisitorUtils.getValue(expr); + if(Boolean.TRUE.equals(o)) { + return true; + } + return false; + } + + /** + * 判断条件是否永假的 + * @param expr + * @return + */ + public static boolean isConditionAlwaysFalse(SQLExpr expr) { + Object o = WallVisitorUtils.getValue(expr); + if(Boolean.FALSE.equals(o)) { + return true; + } + return false; + } + + + /** + * 该方法,返回是否是ER子表 + * @param schema + * @param origSQL + * @param sc + * @return + * @throws SQLNonTransientException + * + * 备注说明: + * edit by ding.w at 2017.4.28, 主要处理 CLIENT_MULTI_STATEMENTS(insert into ; insert into)的情况 + * 目前仅支持mysql,并COM_QUERY请求包中的所有insert语句要么全部是er表,要么全部不是 + * + * + */ + public static boolean processERChildTable(final SchemaConfig schema, final String origSQL, + final ServerConnection sc) throws SQLNonTransientException { + + MySqlStatementParser parser = new MySqlStatementParser(origSQL); + List statements = parser.parseStatementList(); + + if(statements == null || statements.isEmpty() ) { + throw new SQLNonTransientException(String.format("无效的SQL语句:%s", origSQL)); + } + + + boolean erFlag = false; //是否是er表 + for(SQLStatement stmt : statements ) { + MySqlInsertStatement insertStmt = (MySqlInsertStatement) stmt; + String tableName = insertStmt.getTableName().getSimpleName().toUpperCase(); + final TableConfig tc = schema.getTables().get(tableName); + + if (null != tc && tc.isChildTable()) { + erFlag = true; + + String sql = insertStmt.toString(); + + final RouteResultset rrs = new RouteResultset(sql, ServerParse.INSERT); + String joinKey = tc.getJoinKey(); + //因为是Insert语句,用MySqlInsertStatement进行parse +// MySqlInsertStatement insertStmt = (MySqlInsertStatement) (new MySqlStatementParser(origSQL)).parseInsert(); + //判断条件完整性,取得解析后语句列中的joinkey列的index + int joinKeyIndex = getJoinKeyIndex(insertStmt.getColumns(), joinKey); + if (joinKeyIndex == -1) { + String inf = "joinKey not provided :" + tc.getJoinKey() + "," + insertStmt; + LOGGER.warn(inf); + throw new SQLNonTransientException(inf); + } + //子表不支持批量插入 + if (isMultiInsert(insertStmt)) { + String msg = "ChildTable multi insert not provided"; + LOGGER.warn(msg); + throw new SQLNonTransientException(msg); + } + //取得joinkey的值 + String joinKeyVal = insertStmt.getValues().getValues().get(joinKeyIndex).toString(); + //解决bug #938,当关联字段的值为char类型时,去掉前后"'" + String realVal = joinKeyVal; + if (joinKeyVal.startsWith("'") && joinKeyVal.endsWith("'") && joinKeyVal.length() > 2) { + realVal = joinKeyVal.substring(1, joinKeyVal.length() - 1); + } + + + + // try to route by ER parent partion key + //如果是二级子表(父表不再有父表),并且分片字段正好是joinkey字段,调用routeByERParentKey + RouteResultset theRrs = RouterUtil.routeByERParentKey(sc, schema, ServerParse.INSERT, sql, rrs, tc, realVal); + if (theRrs != null) { + boolean processedInsert=false; + //判断是否需要全局序列号 + if ( sc!=null && tc.isAutoIncrement()) { + String primaryKey = tc.getPrimaryKey(); + processedInsert=processInsert(sc,schema,ServerParse.INSERT,sql,tc.getName(),primaryKey); + } + if(processedInsert==false){ + rrs.setFinishedRoute(true); + sc.getSession2().execute(rrs, ServerParse.INSERT); + } + // return true; + //继续处理下一条 + continue; + } + + // route by sql query root parent's datanode + //如果不是二级子表或者分片字段不是joinKey字段结果为空,则启动异步线程去后台分片查询出datanode + //只要查询出上一级表的parentkey字段的对应值在哪个分片即可 + final String findRootTBSql = tc.getLocateRTableKeySql().toLowerCase() + joinKeyVal; + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("find root parent's node sql " + findRootTBSql); + } + + ListenableFuture listenableFuture = MycatServer.getInstance(). + getListeningExecutorService().submit(new Callable() { + @Override + public String call() throws Exception { + FetchStoreNodeOfChildTableHandler fetchHandler = new FetchStoreNodeOfChildTableHandler(); +// return fetchHandler.execute(schema.getName(), findRootTBSql, tc.getRootParent().getDataNodes()); + return fetchHandler.execute(schema.getName(), findRootTBSql, tc.getRootParent().getDataNodes(), sc); + } + }); + + + Futures.addCallback(listenableFuture, new FutureCallback() { + @Override + public void onSuccess(String result) { + //结果为空,证明上一级表中不存在那条记录,失败 + if (Strings.isNullOrEmpty(result)) { + StringBuilder s = new StringBuilder(); + LOGGER.warn(s.append(sc.getSession2()).append(origSQL).toString() + + " err:" + "can't find (root) parent sharding node for sql:" + origSQL); + if(!sc.isAutocommit()) { // 处于事务下失败, 必须回滚 + sc.setTxInterrupt("can't find (root) parent sharding node for sql:" + origSQL); + } + sc.writeErrMessage(ErrorCode.ER_PARSE_ERROR, "can't find (root) parent sharding node for sql:" + origSQL); + return; + } + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("found partion node for child table to insert " + result + " sql :" + origSQL); + } + //找到分片,进行插入(和其他的一样,需要判断是否需要全局自增ID) + boolean processedInsert=false; + if ( sc!=null && tc.isAutoIncrement()) { + try { + String primaryKey = tc.getPrimaryKey(); + processedInsert=processInsert(sc,schema,ServerParse.INSERT,origSQL,tc.getName(),primaryKey); + } catch (SQLNonTransientException e) { + LOGGER.warn("sequence processInsert error,",e); + sc.writeErrMessage(ErrorCode.ER_PARSE_ERROR , "sequence processInsert error," + e.getMessage()); + } + } + if(processedInsert==false){ + RouteResultset executeRrs = RouterUtil.routeToSingleNode(rrs, result, origSQL); + sc.getSession2().execute(executeRrs, ServerParse.INSERT); + } + + } + + @Override + public void onFailure(Throwable t) { + StringBuilder s = new StringBuilder(); + LOGGER.warn(s.append(sc.getSession2()).append(origSQL).toString() + + " err:" + t.getMessage()); + sc.writeErrMessage(ErrorCode.ER_PARSE_ERROR, t.getMessage() + " " + s.toString()); + } + }, MycatServer.getInstance(). + getListeningExecutorService()); + + } else if(erFlag) { + throw new SQLNonTransientException(String.format("%s包含不是ER分片的表", origSQL)); + } + } + + + return erFlag; + } + + /** + * 寻找joinKey的索引 + * + * @param columns + * @param joinKey + * @return -1表示没找到,>=0表示找到了 + */ + private static int getJoinKeyIndex(List columns, String joinKey) { + for (int i = 0; i < columns.size(); i++) { + String col = StringUtil.removeBackquote(columns.get(i).toString()).toUpperCase(); + if (col.equals(joinKey)) { + return i; + } + } + return -1; + } + + /** + * 是否为批量插入:insert into ...values (),()...或 insert into ...select..... + * + * @param insertStmt + * @return + */ + private static boolean isMultiInsert(MySqlInsertStatement insertStmt) { + return (insertStmt.getValuesList() != null && insertStmt.getValuesList().size() > 1) + || insertStmt.getQuery() != null; + } + +} diff --git a/src/test/java/io/mycat/route/DruidMysqlRouteStrategyTest.java b/src/test/java/io/mycat/route/DruidMysqlRouteStrategyTest.java index 1296c4941..0851cb213 100644 --- a/src/test/java/io/mycat/route/DruidMysqlRouteStrategyTest.java +++ b/src/test/java/io/mycat/route/DruidMysqlRouteStrategyTest.java @@ -1,1154 +1,1143 @@ -package io.mycat.route; - -import java.sql.SQLNonTransientException; -import java.util.Collection; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.NoSuchElementException; -import java.util.Set; - -import org.junit.Before; -import org.junit.Test; - -import com.alibaba.druid.sql.ast.SQLStatement; -import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser; - -import io.mycat.MycatServer; -import io.mycat.SimpleCachePool; -import io.mycat.cache.LayerCachePool; -import io.mycat.config.loader.SchemaLoader; -import io.mycat.config.loader.xml.XMLSchemaLoader; -import io.mycat.config.model.SchemaConfig; -import io.mycat.config.model.SystemConfig; -import io.mycat.route.RouteResultset; -import io.mycat.route.RouteResultsetNode; -import io.mycat.route.RouteStrategy; -import io.mycat.route.factory.RouteStrategyFactory; -import io.mycat.server.parser.ServerParse; -import junit.framework.Assert; -import junit.framework.TestCase; - -public class DruidMysqlRouteStrategyTest extends TestCase { - protected Map schemaMap; - protected LayerCachePool cachePool = new SimpleCachePool(); - protected RouteStrategy routeStrategy ; - - public DruidMysqlRouteStrategyTest() { - String schemaFile = "/route/schema.xml"; - String ruleFile = "/route/rule.xml"; - SchemaLoader schemaLoader = new XMLSchemaLoader(schemaFile, ruleFile); - schemaMap = schemaLoader.getSchemas(); - MycatServer.getInstance().getConfig().getSchemas().putAll(schemaMap); - RouteStrategyFactory.init(); - routeStrategy = RouteStrategyFactory.getRouteStrategy("druidparser"); - } - - protected void setUp() throws Exception { - // super.setUp(); - // schemaMap = CobarServer.getInstance().getConfig().getSchemas(); - - } - -// public void testAlias() throws Exception { -// String sql = "SELECT UM.UserId , UM.MenuId ,SM.ParentId ,SM.FullName , SM.Description , SM.Img , SM.NavigateUrl ,SM.FormName ,SM.Target ,SM.IsUnfold FROM Lever_SysMenu SM INNER JOIN ( SELECT UR.UserId AS UserId , RM.MenuId AS MenuId FROM Lever_RoleMenu RM INNER JOIN Lever_UserRole UR ON RM.RoleId = UR.RoleId UNION SELECT UserId , MenuId FROM Lever_UserMenu UNION SELECT U.UserId , RM.MenuId FROM Lever_User U LEFT JOIN Lever_RoleMenu RM ON U.RoleId = RM.RoleId WHERE U.UserId = '8d28533f-1762-4e79-b71f-64eb1a50cb8b' ) UM ON SM.MenuId = UM.MenuId WHERE UM.UserId = '8d28533f-1762-4e79-b71f-64eb1a50cb8b' AND SM.Enabled = 1 ORDER BY SM.SortCode"; -// SchemaConfig schema = schemaMap.get("wdw"); -// RouteResultset rrs = routeStrategy.route(new SystemConfig(),schema, -1, sql, null, -// null, cachePool); -// } - - - public void testRouteInsertShort() throws Exception { - String sql = "inSErt into offer_detail (`offer_id`, gmt) values (123,now())"; - SchemaConfig schema = schemaMap.get("cndb"); - RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, - null, cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(-1l, rrs.getLimitSize()); - Assert.assertEquals("detail_dn15", rrs.getNodes()[0].getName()); - Assert.assertEquals( - "inSErt into offer_detail (`offer_id`, gmt) values (123,now())", - rrs.getNodes()[0].getStatement()); - - sql = "inSErt into offer_detail ( gmt) values (now())"; - schema = schemaMap.get("cndb"); - try { - rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); - } catch (Exception e) { - String msg = "bad insert sql (sharding column:"; - Assert.assertTrue(e.getMessage().contains(msg)); - } - sql = "inSErt into offer_detail (offer_id, gmt) values (123,now())"; - schema = schemaMap.get("cndb"); - rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(-1l, rrs.getLimitSize()); - Assert.assertEquals("detail_dn15", rrs.getNodes()[0].getName()); - Assert.assertEquals( - "inSErt into offer_detail (offer_id, gmt) values (123,now())", - rrs.getNodes()[0].getStatement()); - - sql = "insert into offer(group_id,offer_id,member_id)values(234,123,'abc')"; - schema = schemaMap.get("cndb"); - rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(-1l, rrs.getLimitSize()); - Assert.assertEquals("offer_dn12", rrs.getNodes()[0].getName()); - Assert.assertEquals( - "insert into offer(group_id,offer_id,member_id)values(234,123,'abc')", - rrs.getNodes()[0].getStatement()); - - - - - sql = "\n" + - " INSERT INTO \n" + - "`offer` \n" + - "(`asf`,member_id) \n" + - "VALUES \n" + - "(' the articles sfroms user selection ','abc')"; - schema = schemaMap.get("cndb"); - rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - - - } - - public void testGlobalTableroute() throws Exception { - String sql = null; - SchemaConfig schema = schemaMap.get("TESTDB"); - RouteResultset rrs = null; - // select of global table route to only one datanode defined - sql = "select * from company where company.name like 'aaa'"; - schema = schemaMap.get("TESTDB"); - rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals(false, rrs.isCacheAble()); - // query of global table only route to one datanode - sql = "insert into company (id,name,level) values(111,'company1',3)"; - schema = schemaMap.get("TESTDB"); - rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); - Assert.assertEquals(3, rrs.getNodes().length); - Assert.assertEquals(false, rrs.isCacheAble()); - - // update of global table route to every datanode defined - sql = "update company set name=name+aaa"; - schema = schemaMap.get("TESTDB"); - rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); - Assert.assertEquals(3, rrs.getNodes().length); - Assert.assertEquals(false, rrs.isCacheAble()); - - // delete of global table route to every datanode defined - sql = "delete from company where id = 1"; - schema = schemaMap.get("TESTDB"); - rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); - Assert.assertEquals(3, rrs.getNodes().length); - Assert.assertEquals(false, rrs.isCacheAble()); - - // company is global table ,will route to differnt tables - schema = schemaMap.get("TESTDB"); - sql = "select * from company A where a.sharding_id=10001 union select * from company B where B.sharding_id =10010"; - Set nodeSet = new HashSet(); - for (int i = 0; i < 10; i++) { - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(1, rrs.getNodes().length); - nodeSet.add(rrs.getNodes()[0].getName()); - - } - Assert.assertEquals(true, nodeSet.size() > 1); - - } - - public void testMoreGlobalTableroute() throws Exception { - String sql = null; - SchemaConfig schema = schemaMap.get("TESTDB"); - RouteResultset rrs = null; - // select of global table route to only one datanode defined - sql = "select * from company,area where area.company_id=company.id "; - schema = schemaMap.get("TESTDB"); - rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals(false, rrs.isCacheAble()); // 全局表涉及到多个节点时,不缓存路由结果 - - } - - public void testRouteMultiTables() throws Exception { - // company is global table ,route to 3 datanode and ignored in route - String sql = "select * from company,customer ,orders where customer.company_id=company.id and orders.customer_id=customer.id and company.name like 'aaa' limit 10"; - SchemaConfig schema = schemaMap.get("TESTDB"); - RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, - null, cachePool); - Assert.assertEquals(2, rrs.getNodes().length); - Assert.assertEquals(true, rrs.isCacheAble()); - Assert.assertEquals(10, rrs.getLimitSize()); - Assert.assertEquals("dn1", rrs.getNodes()[0].getName()); - Assert.assertEquals("dn2", rrs.getNodes()[1].getName()); - - } - - public void testRouteCache() throws Exception { - // select cache ID - this.cachePool.putIfAbsent("TESTDB_EMPLOYEE", "88", "dn2"); - - SchemaConfig schema = schemaMap.get("TESTDB"); - String sql = "select * from employee where id=88"; - RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, - null, cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals(false, rrs.isCacheAble());//已经缓存了,不必再缓存了 - Assert.assertEquals(null, rrs.getPrimaryKey()); - Assert.assertEquals(-1, rrs.getLimitSize()); - Assert.assertEquals("dn2", rrs.getNodes()[0].getName()); - - // select cache ID not found ,return all node and rrst not cached - sql = "select * from employee where id=89"; - rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); - Assert.assertEquals(2, rrs.getNodes().length); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals("TESTDB_EMPLOYEE.ID", rrs.getPrimaryKey()); - Assert.assertEquals(-1, rrs.getLimitSize()); - - // update cache ID found - sql = "update employee set name='aaa' where id=88"; - rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(null, rrs.getPrimaryKey()); - Assert.assertEquals("dn2", rrs.getNodes()[0].getName()); - - // delete cache ID found - sql = "delete from employee where id=88"; - rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals("dn2", rrs.getNodes()[0].getName()); - - } - - private static Map getNodeMap( - RouteResultset rrs, int expectSize) { - RouteResultsetNode[] routeNodes = rrs.getNodes(); - Assert.assertEquals(expectSize, routeNodes.length); - Map nodeMap = new HashMap( - expectSize, 1); - for (int i = 0; i < expectSize; i++) { - RouteResultsetNode routeNode = routeNodes[i]; - nodeMap.put(routeNode.getName(), routeNode); - } - Assert.assertEquals(expectSize, nodeMap.size()); - return nodeMap; - } - - private static interface NodeNameDeconstructor { - public int getNodeIndex(String name); - } - - private static class NodeNameAsserter implements NodeNameDeconstructor { - private String[] expectNames; - - public NodeNameAsserter() { - } - - public NodeNameAsserter(String... expectNames) { - Assert.assertNotNull(expectNames); - this.expectNames = expectNames; - } - - protected void setNames(String[] expectNames) { - Assert.assertNotNull(expectNames); - this.expectNames = expectNames; - } - - public void assertRouteNodeNames(Collection nodeNames) { - Assert.assertNotNull(nodeNames); - Assert.assertEquals(expectNames.length, nodeNames.size()); - for (String name : expectNames) { - Assert.assertTrue(nodeNames.contains(name)); - } - } - - @Override - public int getNodeIndex(String name) { - for (int i = 0; i < expectNames.length; ++i) { - if (name.equals(expectNames[i])) { - return i; - } - } - throw new NoSuchElementException("route node " + name - + " dosn't exist!"); - } - } - - private static class IndexedNodeNameAsserter extends NodeNameAsserter { - /** - * @param from included - * @param to excluded - */ - public IndexedNodeNameAsserter(String prefix, int from, int to) { - super(); - String[] names = new String[to - from]; - for (int i = 0; i < names.length; ++i) { - names[i] = prefix + (i + from) ; - } - setNames(names); - } - } - - private static class RouteNodeAsserter { - private NodeNameDeconstructor deconstructor; - private SQLAsserter sqlAsserter; - - public RouteNodeAsserter(NodeNameDeconstructor deconstructor, - SQLAsserter sqlAsserter) { - this.deconstructor = deconstructor; - this.sqlAsserter = sqlAsserter; - } - - public void assertNode(RouteResultsetNode node) throws Exception { - int nodeIndex = deconstructor.getNodeIndex(node.getName()); - sqlAsserter.assertSQL(node.getStatement(), nodeIndex); - } - } - - private static interface SQLAsserter { - public void assertSQL(String sql, int nodeIndex) throws Exception; - } - - private static class SimpleSQLAsserter implements SQLAsserter { - private Map> map = new HashMap>(); - - public SimpleSQLAsserter addExpectSQL(int nodeIndex, String sql) { - Set set = map.get(nodeIndex); - if (set == null) { - set = new HashSet(); - map.put(nodeIndex, set); - } - set.add(sql); - return this; - } - - @Override - public void assertSQL(String sql, int nodeIndex) throws Exception { - Assert.assertNotNull(map.get(nodeIndex)); - Assert.assertTrue(map.get(nodeIndex).contains(sql)); - } - } - - public void testroute() throws Exception { - SchemaConfig schema = schemaMap.get("cndb"); - - String sql = "select * from independent where member='abc'"; - RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, - cachePool); - Assert.assertEquals(true, rrs.isCacheAble()); - Map nodeMap = getNodeMap(rrs, 128); - IndexedNodeNameAsserter nameAsserter = new IndexedNodeNameAsserter( - "independent_dn", 0, 128); - nameAsserter.assertRouteNodeNames(nodeMap.keySet()); - SimpleSQLAsserter sqlAsserter = new SimpleSQLAsserter(); - for (int i = 0; i < 128; ++i) { - sqlAsserter.addExpectSQL(i, - "select * from independent where member='abc'"); - } - RouteNodeAsserter asserter = new RouteNodeAsserter(nameAsserter, - sqlAsserter); - for (RouteResultsetNode node : nodeMap.values()) { - asserter.assertNode(node); - } - - // include database schema ,should remove - sql = "select * from cndb.independent A where a.member='abc'"; - schema = schemaMap.get("cndb"); - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(true, rrs.isCacheAble()); - nodeMap = getNodeMap(rrs, 128); - nameAsserter = new IndexedNodeNameAsserter("independent_dn", 0, 128); - nameAsserter.assertRouteNodeNames(nodeMap.keySet()); - sqlAsserter = new SimpleSQLAsserter(); - for (int i = 0; i < 128; ++i) { - sqlAsserter.addExpectSQL(i, - "select * from independent A where a.member='abc'"); - } - asserter = new RouteNodeAsserter(nameAsserter, sqlAsserter); - for (RouteResultsetNode node : nodeMap.values()) { - asserter.assertNode(node); - } - - } - - public void testERroute() throws Exception { - SchemaConfig schema = schemaMap.get("TESTDB"); - String sql = "insert into orders (id,name,customer_id) values(1,'testonly',1)"; - RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, - cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals("dn1", rrs.getNodes()[0].getName()); - - sql = "insert into orders (id,name,customer_id) values(1,'testonly',2000001)"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals("dn2", rrs.getNodes()[0].getName()); - - // can't update join key - sql = "update orders set id=1 ,name='aaa' , customer_id=2000001"; - String err = null; - try { - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - } catch (SQLNonTransientException e) { - err = e.getMessage(); - } - Assert.assertEquals( - true, - err.startsWith("Parent relevant column can't be updated ORDERS->CUSTOMER_ID")); - - // route by parent rule ,update sql - sql = "update orders set id=1 ,name='aaa' where customer_id=2000001"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(true, rrs.isCacheAble()); - Assert.assertEquals("dn2", rrs.getNodes()[0].getName()); - - // route by parent rule but can't find datanode - sql = "update orders set id=1 ,name='aaa' where customer_id=-1"; - try { - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - } catch (Exception e) { - err = e.getMessage(); - } - Assert.assertEquals(true, - err.startsWith("can't find datanode for sharding column:")); - - // route by parent rule ,select sql - sql = "select * from orders where customer_id=2000001"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(true, rrs.isCacheAble()); - Assert.assertEquals("dn2", rrs.getNodes()[0].getName()); - - // route by parent rule ,delete sql - sql = "delete from orders where customer_id=2000001"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals("dn2", rrs.getNodes()[0].getName()); - - //test alias in column - sql = "select name as order_name from orders order by order_name limit 10,5"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - MySqlStatementParser parser = new MySqlStatementParser("SELECT name AS order_name FROM orders ORDER BY order_name LIMIT 0,15"); - SQLStatement statement = parser.parseStatement(); - -// Assert.assertEquals(sql, rrs.getNodes()[0].getStatement()); - } - - public void testDuplicatePartitionKey() throws Exception { - String sql = null; - SchemaConfig schema = schemaMap.get("cndb"); - RouteResultset rrs = null; - - sql = "select * from cndb.offer where (offer_id, group_id ) In (123,234)"; - schema = schemaMap.get("cndb"); - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(true, rrs.isCacheAble()); - Assert.assertEquals(-1l, rrs.getLimitSize()); - Assert.assertEquals(128, rrs.getNodes().length); - for (int i = 0; i < 128; i++) { -// Assert.assertEquals("offer_dn" + i , -// rrs.getNodes()[i].getName());//node的排序有变化,所以此处不强求 - Assert.assertEquals( - "select * from offer where (offer_id, group_id ) In (123,234)", - rrs.getNodes()[i].getStatement()); - } - - sql = "SELECT * FROM offer WHERE FALSE OR offer_id = 123 AND member_id = 123 OR member_id = 123 AND member_id = 234 OR member_id = 123 AND member_id = 345 OR member_id = 123 AND member_id = 456 OR offer_id = 234 AND group_id = 123 OR offer_id = 234 AND group_id = 234 OR offer_id = 234 AND group_id = 345 OR offer_id = 234 AND group_id = 456 OR offer_id = 345 AND group_id = 123 OR offer_id = 345 AND group_id = 234 OR offer_id = 345 AND group_id = 345 OR offer_id = 345 AND group_id = 456 OR offer_id = 456 AND group_id = 123 OR offer_id = 456 AND group_id = 234 OR offer_id = 456 AND group_id = 345 OR offer_id = 456 AND group_id = 456"; - schema = schemaMap.get("cndb"); - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(true, rrs.isCacheAble()); - getNodeMap(rrs, 128); - - sql = "select * from offer where false" - + " or offer_id=123 and group_id=123" - + " or group_id=123 and offer_id=234" - + " or offer_id=123 and group_id=345" - + " or offer_id=123 and group_id=456 "; - schema = schemaMap.get("cndb"); - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(true, rrs.isCacheAble()); - Assert.assertEquals(-1l, rrs.getLimitSize()); - - } - - public void testAddLimitToSQL() throws Exception { - final SchemaConfig schema = schemaMap.get("TESTDB"); - - String sql = null; - RouteResultset rrs = null; - - sql = "select * from orders"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertEquals(true, rrs.isCacheAble()); - Map nodeMap = getNodeMap(rrs, 2); - NodeNameAsserter nameAsserter = new NodeNameAsserter("dn2", - "dn1"); - nameAsserter.assertRouteNodeNames(nodeMap.keySet()); - Assert.assertEquals(schema.getDefaultMaxLimit(), rrs.getLimitSize()); -// Assert.assertEquals("SELECT * FROM orders LIMIT 100", rrs.getNodes()[0].getStatement()); - MySqlStatementParser parser = new MySqlStatementParser("SELECT * FROM orders LIMIT 100"); - SQLStatement statement = parser.parseStatement(); - Assert.assertEquals(statement.toString(), rrs.getNodes()[0].getStatement()); - - - sql = "select * from goods"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals(schema.getDefaultMaxLimit(), rrs.getLimitSize()); -// Assert.assertEquals("select * from goods", rrs.getNodes()[0].getStatement()); - parser = new MySqlStatementParser("SELECT * FROM goods LIMIT 100"); - statement = parser.parseStatement(); - Assert.assertEquals(statement.toString(), rrs.getNodes()[0].getStatement()); - - - sql = "select * from goods limit 2 ,3"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(1, rrs.getNodes().length); -// Assert.assertEquals(-1, rrs.getLimitSize()); - Assert.assertEquals("select * from goods limit 2 ,3", rrs.getNodes()[0].getStatement()); - - - sql = "select * from notpartionTable limit 2 ,3"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertEquals(true, rrs.isCacheAble()); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals(3, rrs.getLimitSize()); - Assert.assertEquals("select * from notpartionTable limit 2 ,3", rrs.getNodes()[0].getStatement()); - - } - - - public void testModifySQLLimit() throws Exception { - final SchemaConfig schema = schemaMap.get("TESTDB"); - - String sql = null; - RouteResultset rrs = null; - //SQL span multi datanode - sql = "select * from orders limit 2,3"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertEquals(true, rrs.isCacheAble()); - Map nodeMap = getNodeMap(rrs, 2); - NodeNameAsserter nameAsserter = new NodeNameAsserter("dn2", - "dn1"); - nameAsserter.assertRouteNodeNames(nodeMap.keySet()); - Assert.assertEquals(3, rrs.getLimitSize()); - - - MySqlStatementParser parser = new MySqlStatementParser("SELECT * FROM orders LIMIT 0,5"); - SQLStatement statement = parser.parseStatement(); - - Assert.assertEquals(statement.toString(), rrs.getNodes()[0].getStatement()); - - //SQL not span multi datanode - sql = "select * from customer where id=10000 limit 2,3"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertEquals(true, rrs.isCacheAble()); - nodeMap = getNodeMap(rrs, 1); - nameAsserter = new NodeNameAsserter("dn1"); - nameAsserter.assertRouteNodeNames(nodeMap.keySet()); - Assert.assertEquals(3, rrs.getLimitSize()); - Assert.assertEquals("select * from customer where id=10000 limit 2,3", rrs.getNodes()[0].getStatement()); - - - } - - public void testGroupLimit() throws Exception { - final SchemaConfig schema = schemaMap.get("cndb"); - - String sql = null; - RouteResultset rrs = null; - - sql = "select count(*) from (select * from(select * from offer_detail where offer_id='123' or offer_id='234' limit 88)offer where offer.member_id='abc' limit 60) w " - + " where w.member_id ='pavarotti17' limit 99"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(true, rrs.isCacheAble()); - // Assert.assertEquals(88L, rrs.getLimitSize()); - // Assert.assertEquals(RouteResultset.SUM_FLAG, rrs.getFlag()); - Map nodeMap = getNodeMap(rrs, 2); - NodeNameAsserter nameAsserter = new NodeNameAsserter("detail_dn29", - "detail_dn15"); - nameAsserter.assertRouteNodeNames(nodeMap.keySet()); - - sql = "select count(*) from (select * from(select max(id) from offer_detail where offer_id='123' or offer_id='234' limit 88)offer where offer.member_id='abc' limit 60) w " - + " where w.member_id ='pavarotti17' limit 99"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(true, rrs.isCacheAble()); - nodeMap = getNodeMap(rrs, 2); - nameAsserter = new NodeNameAsserter("detail_dn29", "detail_dn15"); - nameAsserter.assertRouteNodeNames(nodeMap.keySet()); - - sql = "select * from (select * from(select max(id) from offer_detail where offer_id='123' or offer_id='234' limit 88)offer where offer.member_id='abc' limit 60) w " - + " where w.member_id ='pavarotti17' limit 99"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(true, rrs.isCacheAble()); - nodeMap = getNodeMap(rrs, 2); - nameAsserter = new NodeNameAsserter("detail_dn29", "detail_dn15"); - nameAsserter.assertRouteNodeNames(nodeMap.keySet()); - - sql = "select * from (select count(*) from(select * from offer_detail where offer_id='123' or offer_id='234' limit 88)offer where offer.member_id='abc' limit 60) w " - + " where w.member_id ='pavarotti17' limit 99"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(true, rrs.isCacheAble()); - // Assert.assertEquals(88L, rrs.getLimitSize()); - // Assert.assertEquals(RouteResultset.SUM_FLAG, rrs.getFlag()); - nodeMap = getNodeMap(rrs, 2); - nameAsserter = new NodeNameAsserter("detail_dn29", "detail_dn15"); - nameAsserter.assertRouteNodeNames(nodeMap.keySet()); - - } - - public void testTableMetaRead() throws Exception { - final SchemaConfig schema = schemaMap.get("cndb"); - - String sql = " desc offer"; - RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.DESCRIBE, sql, null, null, - cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(-1L, rrs.getLimitSize()); - Assert.assertEquals(1, rrs.getNodes().length); - // random return one node - // Assert.assertEquals("offer_dn[0]", rrs.getNodes()[0].getName()); - Assert.assertEquals("desc offer", rrs.getNodes()[0].getStatement()); - - sql = "desc cndb.offer"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.DESCRIBE, sql, null, null, cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(-1L, rrs.getLimitSize()); - Assert.assertEquals(1, rrs.getNodes().length); - // random return one node - // Assert.assertEquals("offer_dn[0]", rrs.getNodes()[0].getName()); - Assert.assertEquals("desc offer", rrs.getNodes()[0].getStatement()); - - sql = "desc cndb.offer col1"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.DESCRIBE, sql, null, null, cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(-1L, rrs.getLimitSize()); - Assert.assertEquals(1, rrs.getNodes().length); - // random return one node - // Assert.assertEquals("offer_dn[0]", rrs.getNodes()[0].getName()); - Assert.assertEquals("desc offer col1", rrs.getNodes()[0].getStatement()); - - sql = "SHOW FULL COLUMNS FROM offer IN db_name WHERE true"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SHOW, sql, null, null, - cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(-1L, rrs.getLimitSize()); - Assert.assertEquals(1, rrs.getNodes().length); - // random return one node - // Assert.assertEquals("offer_dn[0]", rrs.getNodes()[0].getName()); - Assert.assertEquals("SHOW FULL COLUMNS FROM offer WHERE true", - rrs.getNodes()[0].getStatement()); - - sql = "SHOW FULL COLUMNS FROM db.offer IN db_name WHERE true"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SHOW, sql, null, null, - cachePool); - Assert.assertEquals(-1L, rrs.getLimitSize()); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(1, rrs.getNodes().length); - // random return one node - // Assert.assertEquals("offer_dn[0]", rrs.getNodes()[0].getName()); - Assert.assertEquals("SHOW FULL COLUMNS FROM offer WHERE true", - rrs.getNodes()[0].getStatement()); - - - sql = "SHOW FULL TABLES FROM `TESTDB` WHERE Table_type != 'VIEW'"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SHOW, sql, null, null, - cachePool); - Assert.assertEquals(-1L, rrs.getLimitSize()); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals("SHOW FULL TABLES WHERE Table_type != 'VIEW'", rrs.getNodes()[0].getStatement()); - - sql = "SHOW INDEX IN offer FROM db_name"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SHOW, sql, null, null, - cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(-1L, rrs.getLimitSize()); - Assert.assertEquals(1, rrs.getNodes().length); - // random return one node - // Assert.assertEquals("offer_dn[0]", rrs.getNodes()[0].getName()); - Assert.assertEquals("SHOW INDEX FROM offer", - rrs.getNodes()[0].getStatement()); - - sql = "SHOW TABLES from db_name like 'solo'"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SHOW, sql, null, null, - cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(-1L, rrs.getLimitSize()); - Map nodeMap = getNodeMap(rrs, 3); - NodeNameAsserter nameAsserter = new NodeNameAsserter("detail_dn0", - "offer_dn0", "independent_dn0"); - nameAsserter.assertRouteNodeNames(nodeMap.keySet()); - SimpleSQLAsserter sqlAsserter = new SimpleSQLAsserter(); - sqlAsserter.addExpectSQL(0, "SHOW TABLES like 'solo'") - .addExpectSQL(1, "SHOW TABLES like 'solo'") - .addExpectSQL(2, "SHOW TABLES like 'solo'") - .addExpectSQL(3, "SHOW TABLES like 'solo'"); - RouteNodeAsserter asserter = new RouteNodeAsserter(nameAsserter, - sqlAsserter); - for (RouteResultsetNode node : nodeMap.values()) { - asserter.assertNode(node); - } - - sql = "SHOW TABLES in db_name "; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SHOW, sql, null, null, - cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(-1L, rrs.getLimitSize()); - nodeMap = getNodeMap(rrs, 3); - nameAsserter = new NodeNameAsserter("detail_dn0", "offer_dn0", - "independent_dn0"); - nameAsserter.assertRouteNodeNames(nodeMap.keySet()); - sqlAsserter = new SimpleSQLAsserter(); - sqlAsserter.addExpectSQL(0, "SHOW TABLES") - .addExpectSQL(1, "SHOW TABLES").addExpectSQL(2, "SHOW TABLES") - .addExpectSQL(3, "SHOW TABLES"); - asserter = new RouteNodeAsserter(nameAsserter, sqlAsserter); - for (RouteResultsetNode node : nodeMap.values()) { - asserter.assertNode(node); - } - - sql = "SHOW TABLeS "; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SHOW, sql, null, null, - cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(-1L, rrs.getLimitSize()); - nodeMap = getNodeMap(rrs, 3); - nameAsserter = new NodeNameAsserter("offer_dn0", "detail_dn0", - "independent_dn0"); - nameAsserter.assertRouteNodeNames(nodeMap.keySet()); - sqlAsserter = new SimpleSQLAsserter(); - sqlAsserter.addExpectSQL(0, "SHOW TABLeS ") - .addExpectSQL(1, "SHOW TABLeS ").addExpectSQL(2, "SHOW TABLeS "); - asserter = new RouteNodeAsserter(nameAsserter, sqlAsserter); - for (RouteResultsetNode node : nodeMap.values()) { - asserter.assertNode(node); - } - } - - public void testConfigSchema() throws Exception { - try { - SchemaConfig schema = schemaMap.get("config"); - String sql = "select * from offer where offer_id=1"; - routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertFalse(true); - } catch (Exception e) { - Assert.assertEquals("route rule for table OFFER is required: select * from offer where offer_id=1", e.getMessage()); - } - try { - SchemaConfig schema = schemaMap.get("config"); - String sql = "select * from offer where col11111=1"; - routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertFalse(true); - } catch (Exception e) { - } - try { - SchemaConfig schema = schemaMap.get("config"); - String sql = "select * from offer "; - routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertFalse(true); - } catch (Exception e) { - } - } - - public void testIgnoreSchema() throws Exception { - SchemaConfig schema = schemaMap.get("ignoreSchemaTest"); - String sql = "select * from offer where offer_id=1"; - RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, - cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals("cndb_dn", rrs.getNodes()[0].getName()); - Assert.assertEquals(sql, rrs.getNodes()[0].getStatement()); - sql = "select * from ignoreSchemaTest.offer1 where ignoreSchemaTest.offer1.offer_id=1"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals("select * from offer1 where offer1.offer_id=1", - rrs.getNodes()[0].getStatement()); - sql = "select * from ignoreSchemaTest2.offer where ignoreSchemaTest2.offer.offer_id=1"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(sql, rrs.getNodes()[0].getStatement(), sql); - sql = "select * from ignoreSchemaTest2.offer a,offer b where ignoreSchemaTest2.offer.offer_id=1"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals( - "select * from ignoreSchemaTest2.offer a,offer b where ignoreSchemaTest2.offer.offer_id=1", - rrs.getNodes()[0].getStatement()); - - } - - public void testNonPartitionSQL() throws Exception { - - SchemaConfig schema = schemaMap.get("cndb"); - String sql = null; - RouteResultset rrs = null; - - schema = schemaMap.get("dubbo2"); - sql = "SHOW TABLES from db_name like 'solo'"; - rrs = routeStrategy.route(new SystemConfig(), schema, 9, sql, null, null, cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(-1L, rrs.getLimitSize()); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals("dn1", rrs.getNodes()[0].getName()); - Assert.assertEquals("SHOW TABLES like 'solo'", - rrs.getNodes()[0].getStatement()); - - schema = schemaMap.get("dubbo"); - sql = "SHOW TABLES from db_name like 'solo'"; - rrs = routeStrategy.route(new SystemConfig(), schema, 9, sql, null, null, cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(-1L, rrs.getLimitSize()); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals("dubbo_dn", rrs.getNodes()[0].getName()); - Assert.assertEquals("SHOW TABLES like 'solo'", - rrs.getNodes()[0].getStatement()); - - - - sql = "desc cndb.offer"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals(-1L, rrs.getLimitSize()); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals("dubbo_dn", rrs.getNodes()[0].getName()); - Assert.assertEquals("desc cndb.offer", rrs.getNodes()[0].getStatement()); - - schema = schemaMap.get("cndb"); - sql = "SHOW fulL TaBLES from db_name like 'solo'"; - rrs = routeStrategy.route(new SystemConfig(), schema, 9, sql, null, null, cachePool); - Assert.assertEquals(false, rrs.isCacheAble()); - Map nodeMap = getNodeMap(rrs, 3); - NodeNameAsserter nameAsserter = new NodeNameAsserter("detail_dn0", - "offer_dn0", "independent_dn0"); - nameAsserter.assertRouteNodeNames(nodeMap.keySet()); - SimpleSQLAsserter sqlAsserter = new SimpleSQLAsserter(); - sqlAsserter.addExpectSQL(0, "SHOW FULL TABLES like 'solo'") - .addExpectSQL(1, "SHOW FULL TABLES like 'solo'") - .addExpectSQL(2, "SHOW FULL TABLES like 'solo'") - .addExpectSQL(3, "SHOW FULL TABLES like 'solo'"); - RouteNodeAsserter asserter = new RouteNodeAsserter(nameAsserter, - sqlAsserter); - for (RouteResultsetNode node : nodeMap.values()) { - asserter.assertNode(node); - } - } - - public void testGlobalTableSingleNodeLimit() throws Exception { - SchemaConfig schema = schemaMap.get("TESTDB"); - String sql = "select * from globalsn"; - RouteResultset rrs = null; - rrs = routeStrategy.route(new SystemConfig(), schema, - ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertEquals(100L, rrs.getLimitSize()); - } - - /** - * select 1 - * select 1 union all select 2 - * - * @throws Exception - */ - public void testSelectNoTable() throws Exception { - SchemaConfig schema = schemaMap.get("TESTDB"); - String sql = "select 1"; - RouteResultset rrs = null; - rrs = routeStrategy.route(new SystemConfig(), schema, - ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - - - sql = "select 1 union select 2"; - rrs = routeStrategy.route(new SystemConfig(), schema, - ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - } - - /** - * 支持insert into ... values (),()... - * 不支持insert into ... select... - * - * @throws Exception - */ - public void testBatchInsert() throws Exception { - - SchemaConfig schema = schemaMap.get("TESTDB"); - RouteResultset rrs = null; - //不支持childtable 批量插入 - String sql = "insert into orders (id,name,customer_id) values(1,'testonly',1),(2,'testonly',2000001)"; - try { - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, - cachePool); - } catch (Exception e) { - Assert.assertEquals("ChildTable multi insert not provided", e.getMessage()); - } - - sql = "insert into employee (id,name,customer_id) select id,name,customer_id from customer"; - try { - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, - cachePool); - } catch (Exception e) { - Assert.assertEquals("TODO:insert into .... select .... not supported!", e.getMessage()); - } - - //分片表批量插入正常 employee - sql = "insert into employee (id,name,sharding_id) values(1,'testonly',10000),(2,'testonly',10010)"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, - cachePool); - Assert.assertEquals(2, rrs.getNodes().length); - Assert.assertEquals(false, rrs.isCacheAble()); - Assert.assertEquals("dn1", rrs.getNodes()[0].getName()); - Assert.assertEquals("dn2", rrs.getNodes()[1].getName()); - String node1Sql = formatSql("insert into employee (id,name,sharding_id) values(1,'testonly',10000)"); - String node2Sql = formatSql("insert into employee (id,name,sharding_id) values(2,'testonly',10010)"); - RouteResultsetNode[] nodes = rrs.getNodes(); - Assert.assertEquals(node1Sql, nodes[0].getStatement()); - Assert.assertEquals(node2Sql, nodes[1].getStatement()); - } - - /** - * insert ... on duplicate key ... update... - * - * @throws Exception - */ - public void testInsertOnDuplicateKey() throws Exception { - SchemaConfig schema = schemaMap.get("TESTDB"); - String sql = "insert into employee (id,name,sharding_id) values(1,'testonly',10000) on duplicate key update name='nihao'"; - RouteResultset rrs = null; - rrs = routeStrategy.route(new SystemConfig(), schema, - ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals("dn1", rrs.getNodes()[0].getName()); - - //insert ... on duplicate key ... update col1 = VALUES(col1),col2 = VALUES(col2) - sql = "insert into employee (id,name,sharding_id) values(1,'testonly',10000) " + - "on duplicate key update name=VALUES(name),id = VALUES(id)"; - rrs = routeStrategy.route(new SystemConfig(), schema, - ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - Assert.assertEquals("dn1", rrs.getNodes()[0].getName()); - - //insert ... on duplicate key ,sharding key can't be updated - sql = "insert into employee (id,name,sharding_id) values(1,'testonly',10000) " + - "on duplicate key update name=VALUES(name),id = VALUES(id),sharding_id = VALUES(sharding_id)"; - - try { - rrs = routeStrategy.route(new SystemConfig(), schema, - ServerParse.SELECT, sql, null, null, cachePool); - } catch (Exception e) { - Assert.assertEquals("Sharding column can't be updated: EMPLOYEE -> SHARDING_ID", e.getMessage()); - } - - - } - - /** - * 测试函数COUNT - * - * @throws Exception - */ - @Test - public void testAggregateExpr() throws Exception { - SchemaConfig schema = schemaMap.get("TESTDB"); - String sql = "select id, name, count(name) from employee group by name;"; - RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertTrue(rrs.getMergeCols().containsKey("COUNT2")); - - sql = "select id, name, count(name) as c from employee group by name;"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertTrue(rrs.getMergeCols().containsKey("c")); - - sql = "select id, name, count(name) c from employee group by name;"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertTrue(rrs.getMergeCols().containsKey("c")); - } - - /** - * 测试between语句的路由 - * - * @throws Exception - */ - @Test - public void testBetweenExpr() throws Exception { -// 0-200M=0 -// 200M1-400M=1 -// 400M1-600M=2 -// 600M1-800M=3 -// 800M1-1000M=4 - - SchemaConfig schema = schemaMap.get("TESTDB"); - String sql = "select * from customer where id between 1 and 5;"; - RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertTrue(rrs.getNodes().length == 1); - Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn1")); - - sql = "select * from customer where id between 1 and 2000001;"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertTrue(rrs.getNodes().length == 2); - - sql = "select * from customer where id between 2000001 and 3000001;"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertTrue(rrs.getNodes().length == 1); - Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn2")); - - sql = "delete from customer where id between 2000001 and 3000001;"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertTrue(rrs.getNodes().length == 1); - Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn2")); - - sql = "update customer set name='newName' where id between 2000001 and 3000001;"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertTrue(rrs.getNodes().length == 1); - Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn2")); - - } - - /** - * 测试or语句的路由 - * - * @throws Exception - */ - @Test - public void testOr() throws Exception { -// 0-200M=0 -// 200M1-400M=1 -// 400M1-600M=2 -// 600M1-800M=3 -// 800M1-1000M=4 - - SchemaConfig schema = schemaMap.get("TESTDB"); - String sql = "select * from customer where sharding_id=10000 or 1=1;"; - RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertTrue(rrs.getNodes().length == 2); - Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn1")); - Assert.assertTrue(rrs.getNodes()[1].getName().equals("dn2")); - - sql = "select * from customer where sharding_id = 10000 or sharding_id = 10010"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn1")); - Assert.assertTrue(rrs.getNodes()[1].getName().equals("dn2")); - - sql = "select * from customer where sharding_id = 10000 or user_id = 'wangwu'"; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn1")); - Assert.assertTrue(rrs.getNodes()[1].getName().equals("dn2")); - } - - /** - * 测试父子表,查询子表的语句路由到多个节点 - * @throws Exception - */ - @Test - public void testERRouteMutiNode() throws Exception { - SchemaConfig schema = schemaMap.get("TESTDB"); - String sql = "select * from orders where customer_id in(1,2000001);"; - RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertTrue(rrs.getNodes().length == 2); - Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn1")); - Assert.assertTrue(rrs.getNodes()[1].getName().equals("dn2")); - } - - /** - * 测试多层or语句 - * - * @throws Exception - */ - @Test - public void testMultiLevelOr() throws Exception { - SchemaConfig schema = schemaMap.get("TESTDB"); - String sql = "select id from travelrecord " - + " where id = 1 and ( fee=3 or days=5 or (traveldate = '2015-05-04 00:00:07.375' " - + " and (user_id=2 or fee=days or fee = 0))) and name = 'zhangsan'" ; - RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - - Assert.assertTrue(rrs.getNodes().length == 1); - - sql = "select id from travelrecord " - + " where id = 1 and ( fee=3 or days=5 or (traveldate = '2015-05-04 00:00:07.375' " - + " and (user_id=2 or fee=days or fee = 0))) and name = 'zhangsan' or id = 2000001" ; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - - Assert.assertTrue(rrs.getNodes().length == 2); - - sql = "select id from travelrecord " - + " where id = 1 and ( fee=3 or days=5 or (traveldate = '2015-05-04 00:00:07.375' " - + " and (user_id=2 or fee=days or fee = 0))) and name = 'zhangsan' or id = 2000001 or id = 4000001" ; - rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - - Assert.assertTrue(rrs.getNodes().length == 3); - } - - /** - * 测试 global table 的or语句 - * - * - * @throws Exception - */ - @Test - public void testGlobalTableOr() throws Exception { - SchemaConfig schema = schemaMap.get("TESTDB"); - String sql = "select id from company where 1 = 1 and name ='company1' or name = 'company2'" ; - for(int i = 0; i < 20; i++) { - RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); - Assert.assertTrue(rrs.getNodes().length == 1); - } - } - - /** - * 测试别名路由 - * - * @throws Exception - */ - public void testAlias() throws Exception { - - SchemaConfig schema = schemaMap.get("TESTDB"); - RouteResultset rrs = null; - //不支持childtable 批量插入 - //update 全局表 - String sql = "update company a set name = '' where a.id = 1;"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, - cachePool); - - Assert.assertEquals(3, rrs.getNodes().length); - - //update带别名时的路由 - sql = "update travelrecord a set name = '' where a.id = 1;"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, - cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - - //别名大小写路由 - sql = "select * from travelrecord A where a.id = 1;"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, - cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - } - - public void testGroupAlias() throws Exception { - - SchemaConfig schema = schemaMap.get("TESTDB"); - RouteResultset rrs = null; - //别名大小写路由 - String sql = "select * from travelrecord A group by a.id = 1;"; - rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, - cachePool); - Assert.assertEquals(1, rrs.getNodes().length); - } - - private String formatSql(String sql) { - MySqlStatementParser parser = new MySqlStatementParser(sql); - SQLStatement stmt = parser.parseStatement(); - return stmt.toString(); - } - - -} +package io.mycat.route; + +import java.sql.SQLNonTransientException; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; + +import com.alibaba.druid.sql.ast.SQLStatement; +import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser; + +import io.mycat.MycatServer; +import io.mycat.SimpleCachePool; +import io.mycat.cache.LayerCachePool; +import io.mycat.config.loader.SchemaLoader; +import io.mycat.config.loader.xml.XMLSchemaLoader; +import io.mycat.config.model.SchemaConfig; +import io.mycat.config.model.SystemConfig; +import io.mycat.route.RouteResultset; +import io.mycat.route.RouteResultsetNode; +import io.mycat.route.RouteStrategy; +import io.mycat.route.factory.RouteStrategyFactory; +import io.mycat.server.parser.ServerParse; +import junit.framework.Assert; +import junit.framework.TestCase; + +public class DruidMysqlRouteStrategyTest extends TestCase { + protected Map schemaMap; + protected LayerCachePool cachePool = new SimpleCachePool(); + protected RouteStrategy routeStrategy ; + + public DruidMysqlRouteStrategyTest() { + String schemaFile = "/route/schema.xml"; + String ruleFile = "/route/rule.xml"; + SchemaLoader schemaLoader = new XMLSchemaLoader(schemaFile, ruleFile); + schemaMap = schemaLoader.getSchemas(); + MycatServer.getInstance().getConfig().getSchemas().putAll(schemaMap); + RouteStrategyFactory.init(); + routeStrategy = RouteStrategyFactory.getRouteStrategy("druidparser"); + } + + protected void setUp() throws Exception { + // super.setUp(); + // schemaMap = CobarServer.getInstance().getConfig().getSchemas(); + + } + +// public void testAlias() throws Exception { +// String sql = "SELECT UM.UserId , UM.MenuId ,SM.ParentId ,SM.FullName , SM.Description , SM.Img , SM.NavigateUrl ,SM.FormName ,SM.Target ,SM.IsUnfold FROM Lever_SysMenu SM INNER JOIN ( SELECT UR.UserId AS UserId , RM.MenuId AS MenuId FROM Lever_RoleMenu RM INNER JOIN Lever_UserRole UR ON RM.RoleId = UR.RoleId UNION SELECT UserId , MenuId FROM Lever_UserMenu UNION SELECT U.UserId , RM.MenuId FROM Lever_User U LEFT JOIN Lever_RoleMenu RM ON U.RoleId = RM.RoleId WHERE U.UserId = '8d28533f-1762-4e79-b71f-64eb1a50cb8b' ) UM ON SM.MenuId = UM.MenuId WHERE UM.UserId = '8d28533f-1762-4e79-b71f-64eb1a50cb8b' AND SM.Enabled = 1 ORDER BY SM.SortCode"; +// SchemaConfig schema = schemaMap.get("wdw"); +// RouteResultset rrs = routeStrategy.route(new SystemConfig(),schema, -1, sql, null, +// null, cachePool); +// } + + + public void testRouteInsertShort() throws Exception { + String sql = "inSErt into offer_detail (`offer_id`, gmt) values (123,now())"; + SchemaConfig schema = schemaMap.get("cndb"); + RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, + null, cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(-1l, rrs.getLimitSize()); + Assert.assertEquals("detail_dn15", rrs.getNodes()[0].getName()); + Assert.assertEquals( + "inSErt into offer_detail (`offer_id`, gmt) values (123,now())", + rrs.getNodes()[0].getStatement()); + + sql = "inSErt into offer_detail ( gmt) values (now())"; + schema = schemaMap.get("cndb"); + try { + rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); + } catch (Exception e) { + String msg = "bad insert sql (sharding column:"; + Assert.assertTrue(e.getMessage().contains(msg)); + } + sql = "inSErt into offer_detail (offer_id, gmt) values (123,now())"; + schema = schemaMap.get("cndb"); + rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(-1l, rrs.getLimitSize()); + Assert.assertEquals("detail_dn15", rrs.getNodes()[0].getName()); + Assert.assertEquals( + "inSErt into offer_detail (offer_id, gmt) values (123,now())", + rrs.getNodes()[0].getStatement()); + + sql = "insert into offer(group_id,offer_id,member_id)values(234,123,'abc')"; + schema = schemaMap.get("cndb"); + rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(-1l, rrs.getLimitSize()); + Assert.assertEquals("offer_dn12", rrs.getNodes()[0].getName()); + Assert.assertEquals( + "insert into offer(group_id,offer_id,member_id)values(234,123,'abc')", + rrs.getNodes()[0].getStatement()); + + + + + sql = "\n" + + " INSERT INTO \n" + + "`offer` \n" + + "(`asf`,member_id) \n" + + "VALUES \n" + + "(' the articles sfroms user selection ','abc')"; + schema = schemaMap.get("cndb"); + rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + + + } + + public void testGlobalTableroute() throws Exception { + String sql = null; + SchemaConfig schema = schemaMap.get("TESTDB"); + RouteResultset rrs = null; + // select of global table route to only one datanode defined + sql = "select * from company where company.name like 'aaa'"; + schema = schemaMap.get("TESTDB"); + rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals(false, rrs.isCacheAble()); + // query of global table only route to one datanode + sql = "insert into company (id,name,level) values(111,'company1',3)"; + schema = schemaMap.get("TESTDB"); + rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); + Assert.assertEquals(3, rrs.getNodes().length); + Assert.assertEquals(false, rrs.isCacheAble()); + + // update of global table route to every datanode defined + sql = "update company set name=name+aaa"; + schema = schemaMap.get("TESTDB"); + rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); + Assert.assertEquals(3, rrs.getNodes().length); + Assert.assertEquals(false, rrs.isCacheAble()); + + // delete of global table route to every datanode defined + sql = "delete from company where id = 1"; + schema = schemaMap.get("TESTDB"); + rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); + Assert.assertEquals(3, rrs.getNodes().length); + Assert.assertEquals(false, rrs.isCacheAble()); + + // company is global table ,will route to differnt tables + schema = schemaMap.get("TESTDB"); + sql = "select * from company A where a.sharding_id=10001 union select * from company B where B.sharding_id =10010"; + Set nodeSet = new HashSet(); + for (int i = 0; i < 10; i++) { + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(1, rrs.getNodes().length); + nodeSet.add(rrs.getNodes()[0].getName()); + + } + Assert.assertEquals(true, nodeSet.size() > 1); + + } + + public void testMoreGlobalTableroute() throws Exception { + String sql = null; + SchemaConfig schema = schemaMap.get("TESTDB"); + RouteResultset rrs = null; + // select of global table route to only one datanode defined + sql = "select * from company,area where area.company_id=company.id "; + schema = schemaMap.get("TESTDB"); + rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals(false, rrs.isCacheAble()); // 全局表涉及到多个节点时,不缓存路由结果 + + } + + public void testRouteMultiTables() throws Exception { + // company is global table ,route to 3 datanode and ignored in route + String sql = "select * from company,customer ,orders where customer.company_id=company.id and orders.customer_id=customer.id and company.name like 'aaa' limit 10"; + SchemaConfig schema = schemaMap.get("TESTDB"); + RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, + null, cachePool); + Assert.assertEquals(2, rrs.getNodes().length); + Assert.assertEquals(true, rrs.isCacheAble()); + Assert.assertEquals(10, rrs.getLimitSize()); + Assert.assertEquals("dn1", rrs.getNodes()[0].getName()); + Assert.assertEquals("dn2", rrs.getNodes()[1].getName()); + + } + + public void testRouteCache() throws Exception { + // select cache ID + this.cachePool.putIfAbsent("TESTDB_EMPLOYEE", "88", "dn2"); + + SchemaConfig schema = schemaMap.get("TESTDB"); + String sql = "select * from employee where id=88"; + RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, + null, cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals(false, rrs.isCacheAble());//已经缓存了,不必再缓存了 + Assert.assertEquals(null, rrs.getPrimaryKey()); + Assert.assertEquals(-1, rrs.getLimitSize()); + Assert.assertEquals("dn2", rrs.getNodes()[0].getName()); + + // select cache ID not found ,return all node and rrst not cached + sql = "select * from employee where id=89"; + rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); + Assert.assertEquals(2, rrs.getNodes().length); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals("TESTDB_EMPLOYEE.ID", rrs.getPrimaryKey()); + Assert.assertEquals(-1, rrs.getLimitSize()); + + // update cache ID found + sql = "update employee set name='aaa' where id=88"; + rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(null, rrs.getPrimaryKey()); + Assert.assertEquals("dn2", rrs.getNodes()[0].getName()); + + // delete cache ID found + sql = "delete from employee where id=88"; + rrs = routeStrategy.route(new SystemConfig(), schema, -1, sql, null, null, cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals("dn2", rrs.getNodes()[0].getName()); + + } + + private static Map getNodeMap( + RouteResultset rrs, int expectSize) { + RouteResultsetNode[] routeNodes = rrs.getNodes(); + Assert.assertEquals(expectSize, routeNodes.length); + Map nodeMap = new HashMap( + expectSize, 1); + for (int i = 0; i < expectSize; i++) { + RouteResultsetNode routeNode = routeNodes[i]; + nodeMap.put(routeNode.getName(), routeNode); + } + Assert.assertEquals(expectSize, nodeMap.size()); + return nodeMap; + } + + private static interface NodeNameDeconstructor { + public int getNodeIndex(String name); + } + + private static class NodeNameAsserter implements NodeNameDeconstructor { + private String[] expectNames; + + public NodeNameAsserter() { + } + + public NodeNameAsserter(String... expectNames) { + Assert.assertNotNull(expectNames); + this.expectNames = expectNames; + } + + protected void setNames(String[] expectNames) { + Assert.assertNotNull(expectNames); + this.expectNames = expectNames; + } + + public void assertRouteNodeNames(Collection nodeNames) { + Assert.assertNotNull(nodeNames); + Assert.assertEquals(expectNames.length, nodeNames.size()); + for (String name : expectNames) { + Assert.assertTrue(nodeNames.contains(name)); + } + } + + @Override + public int getNodeIndex(String name) { + for (int i = 0; i < expectNames.length; ++i) { + if (name.equals(expectNames[i])) { + return i; + } + } + throw new NoSuchElementException("route node " + name + + " dosn't exist!"); + } + } + + private static class IndexedNodeNameAsserter extends NodeNameAsserter { + /** + * @param from included + * @param to excluded + */ + public IndexedNodeNameAsserter(String prefix, int from, int to) { + super(); + String[] names = new String[to - from]; + for (int i = 0; i < names.length; ++i) { + names[i] = prefix + (i + from) ; + } + setNames(names); + } + } + + private static class RouteNodeAsserter { + private NodeNameDeconstructor deconstructor; + private SQLAsserter sqlAsserter; + + public RouteNodeAsserter(NodeNameDeconstructor deconstructor, + SQLAsserter sqlAsserter) { + this.deconstructor = deconstructor; + this.sqlAsserter = sqlAsserter; + } + + public void assertNode(RouteResultsetNode node) throws Exception { + int nodeIndex = deconstructor.getNodeIndex(node.getName()); + sqlAsserter.assertSQL(node.getStatement(), nodeIndex); + } + } + + private static interface SQLAsserter { + public void assertSQL(String sql, int nodeIndex) throws Exception; + } + + private static class SimpleSQLAsserter implements SQLAsserter { + private Map> map = new HashMap>(); + + public SimpleSQLAsserter addExpectSQL(int nodeIndex, String sql) { + Set set = map.get(nodeIndex); + if (set == null) { + set = new HashSet(); + map.put(nodeIndex, set); + } + set.add(sql); + return this; + } + + @Override + public void assertSQL(String sql, int nodeIndex) throws Exception { + Assert.assertNotNull(map.get(nodeIndex)); + Assert.assertTrue(map.get(nodeIndex).contains(sql)); + } + } + + public void testroute() throws Exception { + SchemaConfig schema = schemaMap.get("cndb"); + + String sql = "select * from independent where member='abc'"; + RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, + cachePool); + Assert.assertEquals(true, rrs.isCacheAble()); + Map nodeMap = getNodeMap(rrs, 128); + IndexedNodeNameAsserter nameAsserter = new IndexedNodeNameAsserter( + "independent_dn", 0, 128); + nameAsserter.assertRouteNodeNames(nodeMap.keySet()); + SimpleSQLAsserter sqlAsserter = new SimpleSQLAsserter(); + for (int i = 0; i < 128; ++i) { + sqlAsserter.addExpectSQL(i, + "select * from independent where member='abc'"); + } + RouteNodeAsserter asserter = new RouteNodeAsserter(nameAsserter, + sqlAsserter); + for (RouteResultsetNode node : nodeMap.values()) { + asserter.assertNode(node); + } + + // include database schema ,should remove + sql = "select * from cndb.independent A where a.member='abc'"; + schema = schemaMap.get("cndb"); + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(true, rrs.isCacheAble()); + nodeMap = getNodeMap(rrs, 128); + nameAsserter = new IndexedNodeNameAsserter("independent_dn", 0, 128); + nameAsserter.assertRouteNodeNames(nodeMap.keySet()); + sqlAsserter = new SimpleSQLAsserter(); + for (int i = 0; i < 128; ++i) { + sqlAsserter.addExpectSQL(i, + "select * from independent A where a.member='abc'"); + } + asserter = new RouteNodeAsserter(nameAsserter, sqlAsserter); + for (RouteResultsetNode node : nodeMap.values()) { + asserter.assertNode(node); + } + + } + + public void testERroute() throws Exception { + SchemaConfig schema = schemaMap.get("TESTDB"); + String sql = "insert into orders (id,name,customer_id) values(1,'testonly',1)"; + RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, + cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals("dn1", rrs.getNodes()[0].getName()); + + sql = "insert into orders (id,name,customer_id) values(1,'testonly',2000001)"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals("dn2", rrs.getNodes()[0].getName()); + + // can't update join key + sql = "update orders set id=1 ,name='aaa' , customer_id=2000001"; + String err = null; + try { + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + } catch (SQLNonTransientException e) { + err = e.getMessage(); + } + Assert.assertEquals( + true, + err.startsWith("Parent relevant column can't be updated ORDERS->CUSTOMER_ID")); + + // route by parent rule ,update sql + sql = "update orders set id=1 ,name='aaa' where customer_id=2000001"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(true, rrs.isCacheAble()); + Assert.assertEquals("dn2", rrs.getNodes()[0].getName()); + + // route by parent rule but can't find datanode + sql = "update orders set id=1 ,name='aaa' where customer_id=-1"; + try { + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + } catch (Exception e) { + err = e.getMessage(); + } + Assert.assertEquals(true, + err.startsWith("can't find datanode for sharding column:")); + + // route by parent rule ,select sql + sql = "select * from orders where customer_id=2000001"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(true, rrs.isCacheAble()); + Assert.assertEquals("dn2", rrs.getNodes()[0].getName()); + + // route by parent rule ,delete sql + sql = "delete from orders where customer_id=2000001"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals("dn2", rrs.getNodes()[0].getName()); + + //test alias in column + sql = "select name as order_name from orders order by order_name limit 10,5"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + MySqlStatementParser parser = new MySqlStatementParser("SELECT name AS order_name FROM orders ORDER BY order_name LIMIT 0,15"); + SQLStatement statement = parser.parseStatement(); + +// Assert.assertEquals(sql, rrs.getNodes()[0].getStatement()); + } + + public void testDuplicatePartitionKey() throws Exception { + String sql = null; + SchemaConfig schema = schemaMap.get("cndb"); + RouteResultset rrs = null; + + sql = "select * from cndb.offer where (offer_id, group_id ) In (123,234)"; + schema = schemaMap.get("cndb"); + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(true, rrs.isCacheAble()); + Assert.assertEquals(-1l, rrs.getLimitSize()); + Assert.assertEquals(128, rrs.getNodes().length); + for (int i = 0; i < 128; i++) { +// Assert.assertEquals("offer_dn" + i , +// rrs.getNodes()[i].getName());//node的排序有变化,所以此处不强求 + Assert.assertEquals( + "select * from offer where (offer_id, group_id ) In (123,234)", + rrs.getNodes()[i].getStatement()); + } + + sql = "SELECT * FROM offer WHERE FALSE OR offer_id = 123 AND member_id = 123 OR member_id = 123 AND member_id = 234 OR member_id = 123 AND member_id = 345 OR member_id = 123 AND member_id = 456 OR offer_id = 234 AND group_id = 123 OR offer_id = 234 AND group_id = 234 OR offer_id = 234 AND group_id = 345 OR offer_id = 234 AND group_id = 456 OR offer_id = 345 AND group_id = 123 OR offer_id = 345 AND group_id = 234 OR offer_id = 345 AND group_id = 345 OR offer_id = 345 AND group_id = 456 OR offer_id = 456 AND group_id = 123 OR offer_id = 456 AND group_id = 234 OR offer_id = 456 AND group_id = 345 OR offer_id = 456 AND group_id = 456"; + schema = schemaMap.get("cndb"); + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(true, rrs.isCacheAble()); + getNodeMap(rrs, 128); + + sql = "select * from offer where false" + + " or offer_id=123 and group_id=123" + + " or group_id=123 and offer_id=234" + + " or offer_id=123 and group_id=345" + + " or offer_id=123 and group_id=456 "; + schema = schemaMap.get("cndb"); + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(true, rrs.isCacheAble()); + Assert.assertEquals(-1l, rrs.getLimitSize()); + + } + + public void testAddLimitToSQL() throws Exception { + final SchemaConfig schema = schemaMap.get("TESTDB"); + + String sql = null; + RouteResultset rrs = null; + + sql = "select * from orders"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertEquals(true, rrs.isCacheAble()); + Map nodeMap = getNodeMap(rrs, 2); + NodeNameAsserter nameAsserter = new NodeNameAsserter("dn2", + "dn1"); + nameAsserter.assertRouteNodeNames(nodeMap.keySet()); + Assert.assertEquals(schema.getDefaultMaxLimit(), rrs.getLimitSize()); +// Assert.assertEquals("SELECT * FROM orders LIMIT 100", rrs.getNodes()[0].getStatement()); + MySqlStatementParser parser = new MySqlStatementParser("SELECT * FROM orders LIMIT 100"); + SQLStatement statement = parser.parseStatement(); + Assert.assertEquals(statement.toString(), rrs.getNodes()[0].getStatement()); + + + sql = "select * from goods"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals(schema.getDefaultMaxLimit(), rrs.getLimitSize()); +// Assert.assertEquals("select * from goods", rrs.getNodes()[0].getStatement()); + parser = new MySqlStatementParser("SELECT * FROM goods LIMIT 100"); + statement = parser.parseStatement(); + Assert.assertEquals(statement.toString(), rrs.getNodes()[0].getStatement()); + + + sql = "select * from goods limit 2 ,3"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(1, rrs.getNodes().length); +// Assert.assertEquals(-1, rrs.getLimitSize()); + Assert.assertEquals("select * from goods limit 2 ,3", rrs.getNodes()[0].getStatement()); + + + sql = "select * from notpartionTable limit 2 ,3"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertEquals(true, rrs.isCacheAble()); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals(3, rrs.getLimitSize()); + Assert.assertEquals("select * from notpartionTable limit 2 ,3", rrs.getNodes()[0].getStatement()); + + } + + + public void testModifySQLLimit() throws Exception { + final SchemaConfig schema = schemaMap.get("TESTDB"); + + String sql = null; + RouteResultset rrs = null; + //SQL span multi datanode + sql = "select * from orders limit 2,3"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertEquals(true, rrs.isCacheAble()); + Map nodeMap = getNodeMap(rrs, 2); + NodeNameAsserter nameAsserter = new NodeNameAsserter("dn2", + "dn1"); + nameAsserter.assertRouteNodeNames(nodeMap.keySet()); + Assert.assertEquals(3, rrs.getLimitSize()); + + + MySqlStatementParser parser = new MySqlStatementParser("SELECT * FROM orders LIMIT 0,5"); + SQLStatement statement = parser.parseStatement(); + + Assert.assertEquals(statement.toString(), rrs.getNodes()[0].getStatement()); + + //SQL not span multi datanode + sql = "select * from customer where id=10000 limit 2,3"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertEquals(true, rrs.isCacheAble()); + nodeMap = getNodeMap(rrs, 1); + nameAsserter = new NodeNameAsserter("dn1"); + nameAsserter.assertRouteNodeNames(nodeMap.keySet()); + Assert.assertEquals(3, rrs.getLimitSize()); + Assert.assertEquals("select * from customer where id=10000 limit 2,3", rrs.getNodes()[0].getStatement()); + + + } + + public void testGroupLimit() throws Exception { + final SchemaConfig schema = schemaMap.get("cndb"); + + String sql = null; + RouteResultset rrs = null; + + sql = "select count(*) from (select * from(select * from offer_detail where offer_id='123' or offer_id='234' limit 88)offer where offer.member_id='abc' limit 60) w " + + " where w.member_id ='pavarotti17' limit 99"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(true, rrs.isCacheAble()); + // Assert.assertEquals(88L, rrs.getLimitSize()); + // Assert.assertEquals(RouteResultset.SUM_FLAG, rrs.getFlag()); + Map nodeMap = getNodeMap(rrs, 2); + NodeNameAsserter nameAsserter = new NodeNameAsserter("detail_dn29", + "detail_dn15"); + nameAsserter.assertRouteNodeNames(nodeMap.keySet()); + + sql = "select count(*) from (select * from(select max(id) from offer_detail where offer_id='123' or offer_id='234' limit 88)offer where offer.member_id='abc' limit 60) w " + + " where w.member_id ='pavarotti17' limit 99"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(true, rrs.isCacheAble()); + nodeMap = getNodeMap(rrs, 2); + nameAsserter = new NodeNameAsserter("detail_dn29", "detail_dn15"); + nameAsserter.assertRouteNodeNames(nodeMap.keySet()); + + sql = "select * from (select * from(select max(id) from offer_detail where offer_id='123' or offer_id='234' limit 88)offer where offer.member_id='abc' limit 60) w " + + " where w.member_id ='pavarotti17' limit 99"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(true, rrs.isCacheAble()); + nodeMap = getNodeMap(rrs, 2); + nameAsserter = new NodeNameAsserter("detail_dn29", "detail_dn15"); + nameAsserter.assertRouteNodeNames(nodeMap.keySet()); + + sql = "select * from (select count(*) from(select * from offer_detail where offer_id='123' or offer_id='234' limit 88)offer where offer.member_id='abc' limit 60) w " + + " where w.member_id ='pavarotti17' limit 99"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(true, rrs.isCacheAble()); + // Assert.assertEquals(88L, rrs.getLimitSize()); + // Assert.assertEquals(RouteResultset.SUM_FLAG, rrs.getFlag()); + nodeMap = getNodeMap(rrs, 2); + nameAsserter = new NodeNameAsserter("detail_dn29", "detail_dn15"); + nameAsserter.assertRouteNodeNames(nodeMap.keySet()); + + } + + public void testTableMetaRead() throws Exception { + final SchemaConfig schema = schemaMap.get("cndb"); + + String sql = " desc offer"; + RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.DESCRIBE, sql, null, null, + cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(-1L, rrs.getLimitSize()); + Assert.assertEquals(1, rrs.getNodes().length); + // random return one node + // Assert.assertEquals("offer_dn[0]", rrs.getNodes()[0].getName()); + Assert.assertEquals("desc offer", rrs.getNodes()[0].getStatement()); + + sql = "desc cndb.offer"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.DESCRIBE, sql, null, null, cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(-1L, rrs.getLimitSize()); + Assert.assertEquals(1, rrs.getNodes().length); + // random return one node + // Assert.assertEquals("offer_dn[0]", rrs.getNodes()[0].getName()); + Assert.assertEquals("desc offer", rrs.getNodes()[0].getStatement()); + + sql = "desc cndb.offer col1"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.DESCRIBE, sql, null, null, cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(-1L, rrs.getLimitSize()); + Assert.assertEquals(1, rrs.getNodes().length); + // random return one node + // Assert.assertEquals("offer_dn[0]", rrs.getNodes()[0].getName()); + Assert.assertEquals("desc offer col1", rrs.getNodes()[0].getStatement()); + + sql = "SHOW FULL COLUMNS FROM offer IN db_name WHERE true"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SHOW, sql, null, null, + cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(-1L, rrs.getLimitSize()); + Assert.assertEquals(1, rrs.getNodes().length); + // random return one node + // Assert.assertEquals("offer_dn[0]", rrs.getNodes()[0].getName()); + Assert.assertEquals("SHOW FULL COLUMNS FROM offer WHERE true", + rrs.getNodes()[0].getStatement()); + + sql = "SHOW FULL COLUMNS FROM db.offer IN db_name WHERE true"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SHOW, sql, null, null, + cachePool); + Assert.assertEquals(-1L, rrs.getLimitSize()); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(1, rrs.getNodes().length); + // random return one node + // Assert.assertEquals("offer_dn[0]", rrs.getNodes()[0].getName()); + Assert.assertEquals("SHOW FULL COLUMNS FROM offer WHERE true", + rrs.getNodes()[0].getStatement()); + + + sql = "SHOW FULL TABLES FROM `TESTDB` WHERE Table_type != 'VIEW'"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SHOW, sql, null, null, + cachePool); + Assert.assertEquals(-1L, rrs.getLimitSize()); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals("SHOW FULL TABLES WHERE Table_type != 'VIEW'", rrs.getNodes()[0].getStatement()); + + sql = "SHOW INDEX IN offer FROM db_name"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SHOW, sql, null, null, + cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(-1L, rrs.getLimitSize()); + Assert.assertEquals(1, rrs.getNodes().length); + // random return one node + // Assert.assertEquals("offer_dn[0]", rrs.getNodes()[0].getName()); + Assert.assertEquals("SHOW INDEX FROM offer", + rrs.getNodes()[0].getStatement()); + + sql = "SHOW TABLES from db_name like 'solo'"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SHOW, sql, null, null, + cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(-1L, rrs.getLimitSize()); + Map nodeMap = getNodeMap(rrs, 3); + NodeNameAsserter nameAsserter = new NodeNameAsserter("detail_dn0", + "offer_dn0", "independent_dn0"); + nameAsserter.assertRouteNodeNames(nodeMap.keySet()); + SimpleSQLAsserter sqlAsserter = new SimpleSQLAsserter(); + sqlAsserter.addExpectSQL(0, "SHOW TABLES like 'solo'") + .addExpectSQL(1, "SHOW TABLES like 'solo'") + .addExpectSQL(2, "SHOW TABLES like 'solo'") + .addExpectSQL(3, "SHOW TABLES like 'solo'"); + RouteNodeAsserter asserter = new RouteNodeAsserter(nameAsserter, + sqlAsserter); + for (RouteResultsetNode node : nodeMap.values()) { + asserter.assertNode(node); + } + + sql = "SHOW TABLES in db_name "; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SHOW, sql, null, null, + cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(-1L, rrs.getLimitSize()); + nodeMap = getNodeMap(rrs, 3); + nameAsserter = new NodeNameAsserter("detail_dn0", "offer_dn0", + "independent_dn0"); + nameAsserter.assertRouteNodeNames(nodeMap.keySet()); + sqlAsserter = new SimpleSQLAsserter(); + sqlAsserter.addExpectSQL(0, "SHOW TABLES") + .addExpectSQL(1, "SHOW TABLES").addExpectSQL(2, "SHOW TABLES") + .addExpectSQL(3, "SHOW TABLES"); + asserter = new RouteNodeAsserter(nameAsserter, sqlAsserter); + for (RouteResultsetNode node : nodeMap.values()) { + asserter.assertNode(node); + } + + sql = "SHOW TABLeS "; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SHOW, sql, null, null, + cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(-1L, rrs.getLimitSize()); + nodeMap = getNodeMap(rrs, 3); + nameAsserter = new NodeNameAsserter("offer_dn0", "detail_dn0", + "independent_dn0"); + nameAsserter.assertRouteNodeNames(nodeMap.keySet()); + sqlAsserter = new SimpleSQLAsserter(); + sqlAsserter.addExpectSQL(0, "SHOW TABLeS ") + .addExpectSQL(1, "SHOW TABLeS ").addExpectSQL(2, "SHOW TABLeS "); + asserter = new RouteNodeAsserter(nameAsserter, sqlAsserter); + for (RouteResultsetNode node : nodeMap.values()) { + asserter.assertNode(node); + } + } + + public void testConfigSchema() throws Exception { + try { + SchemaConfig schema = schemaMap.get("config"); + String sql = "select * from offer where offer_id=1"; + routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertFalse(true); + } catch (Exception e) { + Assert.assertEquals("route rule for table OFFER is required: select * from offer where offer_id=1", e.getMessage()); + } + try { + SchemaConfig schema = schemaMap.get("config"); + String sql = "select * from offer where col11111=1"; + routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertFalse(true); + } catch (Exception e) { + } + try { + SchemaConfig schema = schemaMap.get("config"); + String sql = "select * from offer "; + routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertFalse(true); + } catch (Exception e) { + } + } + + public void testIgnoreSchema() throws Exception { + SchemaConfig schema = schemaMap.get("ignoreSchemaTest"); + String sql = "select * from offer where offer_id=1"; + RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, + cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals("cndb_dn", rrs.getNodes()[0].getName()); + Assert.assertEquals(sql, rrs.getNodes()[0].getStatement()); + sql = "select * from ignoreSchemaTest.offer1 where ignoreSchemaTest.offer1.offer_id=1"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals("select * from offer1 where offer1.offer_id=1", + rrs.getNodes()[0].getStatement()); + sql = "select * from ignoreSchemaTest2.offer where ignoreSchemaTest2.offer.offer_id=1"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(sql, rrs.getNodes()[0].getStatement(), sql); + sql = "select * from ignoreSchemaTest2.offer a,offer b where ignoreSchemaTest2.offer.offer_id=1"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals( + "select * from ignoreSchemaTest2.offer a,offer b where ignoreSchemaTest2.offer.offer_id=1", + rrs.getNodes()[0].getStatement()); + + } + + public void testNonPartitionSQL() throws Exception { + + SchemaConfig schema = schemaMap.get("cndb"); + String sql = null; + RouteResultset rrs = null; + + schema = schemaMap.get("dubbo2"); + sql = "SHOW TABLES from db_name like 'solo'"; + rrs = routeStrategy.route(new SystemConfig(), schema, 9, sql, null, null, cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(-1L, rrs.getLimitSize()); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals("dn1", rrs.getNodes()[0].getName()); + Assert.assertEquals("SHOW TABLES like 'solo'", + rrs.getNodes()[0].getStatement()); + + schema = schemaMap.get("dubbo"); + sql = "SHOW TABLES from db_name like 'solo'"; + rrs = routeStrategy.route(new SystemConfig(), schema, 9, sql, null, null, cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(-1L, rrs.getLimitSize()); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals("dubbo_dn", rrs.getNodes()[0].getName()); + Assert.assertEquals("SHOW TABLES like 'solo'", + rrs.getNodes()[0].getStatement()); + + + + sql = "desc cndb.offer"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals(-1L, rrs.getLimitSize()); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals("dubbo_dn", rrs.getNodes()[0].getName()); + Assert.assertEquals("desc cndb.offer", rrs.getNodes()[0].getStatement()); + + schema = schemaMap.get("cndb"); + sql = "SHOW fulL TaBLES from db_name like 'solo'"; + rrs = routeStrategy.route(new SystemConfig(), schema, 9, sql, null, null, cachePool); + Assert.assertEquals(false, rrs.isCacheAble()); + Map nodeMap = getNodeMap(rrs, 3); + NodeNameAsserter nameAsserter = new NodeNameAsserter("detail_dn0", + "offer_dn0", "independent_dn0"); + nameAsserter.assertRouteNodeNames(nodeMap.keySet()); + SimpleSQLAsserter sqlAsserter = new SimpleSQLAsserter(); + sqlAsserter.addExpectSQL(0, "SHOW FULL TABLES like 'solo'") + .addExpectSQL(1, "SHOW FULL TABLES like 'solo'") + .addExpectSQL(2, "SHOW FULL TABLES like 'solo'") + .addExpectSQL(3, "SHOW FULL TABLES like 'solo'"); + RouteNodeAsserter asserter = new RouteNodeAsserter(nameAsserter, + sqlAsserter); + for (RouteResultsetNode node : nodeMap.values()) { + asserter.assertNode(node); + } + } + + public void testGlobalTableSingleNodeLimit() throws Exception { + SchemaConfig schema = schemaMap.get("TESTDB"); + String sql = "select * from globalsn"; + RouteResultset rrs = null; + rrs = routeStrategy.route(new SystemConfig(), schema, + ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertEquals(100L, rrs.getLimitSize()); + } + + /** + * select 1 + * select 1 union all select 2 + * + * @throws Exception + */ + public void testSelectNoTable() throws Exception { + SchemaConfig schema = schemaMap.get("TESTDB"); + String sql = "select 1"; + RouteResultset rrs = null; + rrs = routeStrategy.route(new SystemConfig(), schema, + ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + + + sql = "select 1 union select 2"; + rrs = routeStrategy.route(new SystemConfig(), schema, + ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + } + + /** + * 支持insert into ... values (),()... + * 不支持insert into ... select... + * + * @throws Exception + */ + public void testBatchInsert() throws Exception { + + SchemaConfig schema = schemaMap.get("TESTDB"); + RouteResultset rrs = null; + //不支持childtable 批量插入 + String sql = "insert into orders (id,name,customer_id) values(1,'testonly',1),(2,'testonly',2000001)"; + try { + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, + cachePool); + } catch (Exception e) { + Assert.assertEquals("ChildTable multi insert not provided", e.getMessage()); + } + + sql = "insert into employee (id,name,customer_id) select id,name,customer_id from customer"; + try { + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, + cachePool); + } catch (Exception e) { + Assert.assertEquals("TODO:insert into .... select .... not supported!", e.getMessage()); + } + + //分片表批量插入正常 employee + sql = "insert into employee (id,name,sharding_id) values(1,'testonly',10000),(2,'testonly',10010)"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, + cachePool); + Assert.assertEquals(2, rrs.getNodes().length); + Assert.assertEquals(false, rrs.isCacheAble()); + Assert.assertEquals("dn1", rrs.getNodes()[0].getName()); + Assert.assertEquals("dn2", rrs.getNodes()[1].getName()); + String node1Sql = formatSql("insert into employee (id,name,sharding_id) values(1,'testonly',10000)"); + String node2Sql = formatSql("insert into employee (id,name,sharding_id) values(2,'testonly',10010)"); + RouteResultsetNode[] nodes = rrs.getNodes(); + Assert.assertEquals(node1Sql, nodes[0].getStatement()); + Assert.assertEquals(node2Sql, nodes[1].getStatement()); + } + + /** + * insert ... on duplicate key ... update... + * + * @throws Exception + */ + public void testInsertOnDuplicateKey() throws Exception { + SchemaConfig schema = schemaMap.get("TESTDB"); + String sql = "insert into employee (id,name,sharding_id) values(1,'testonly',10000) on duplicate key update name='nihao'"; + RouteResultset rrs = null; + rrs = routeStrategy.route(new SystemConfig(), schema, + ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals("dn1", rrs.getNodes()[0].getName()); + + //insert ... on duplicate key ... update col1 = VALUES(col1),col2 = VALUES(col2) + sql = "insert into employee (id,name,sharding_id) values(1,'testonly',10000) " + + "on duplicate key update name=VALUES(name),id = VALUES(id)"; + rrs = routeStrategy.route(new SystemConfig(), schema, + ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + Assert.assertEquals("dn1", rrs.getNodes()[0].getName()); + + //insert ... on duplicate key ,sharding key can't be updated + sql = "insert into employee (id,name,sharding_id) values(1,'testonly',10000) " + + "on duplicate key update name=VALUES(name),id = VALUES(id),sharding_id = VALUES(sharding_id)"; + + try { + rrs = routeStrategy.route(new SystemConfig(), schema, + ServerParse.SELECT, sql, null, null, cachePool); + } catch (Exception e) { + Assert.assertEquals("Sharding column can't be updated: EMPLOYEE -> SHARDING_ID", e.getMessage()); + } + + + } + + /** + * 测试函数COUNT + * + * @throws Exception + */ + @Test + public void testAggregateExpr() throws Exception { + SchemaConfig schema = schemaMap.get("TESTDB"); + String sql = "select id, name, count(name) from employee group by name;"; + RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertTrue(rrs.getMergeCols().containsKey("COUNT2")); + + sql = "select id, name, count(name) as c from employee group by name;"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertTrue(rrs.getMergeCols().containsKey("c")); + + sql = "select id, name, count(name) c from employee group by name;"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertTrue(rrs.getMergeCols().containsKey("c")); + } + + /** + * 测试between语句的路由 + * + * @throws Exception + */ + @Test + public void testBetweenExpr() throws Exception { +// 0-200M=0 +// 200M1-400M=1 +// 400M1-600M=2 +// 600M1-800M=3 +// 800M1-1000M=4 + + SchemaConfig schema = schemaMap.get("TESTDB"); + String sql = "select * from customer where id between 1 and 5;"; + RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertTrue(rrs.getNodes().length == 1); + Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn1")); + + sql = "select * from customer where id between 1 and 2000001;"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertTrue(rrs.getNodes().length == 2); + + sql = "select * from customer where id between 2000001 and 3000001;"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertTrue(rrs.getNodes().length == 1); + Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn2")); + + sql = "delete from customer where id between 2000001 and 3000001;"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertTrue(rrs.getNodes().length == 1); + Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn2")); + + sql = "update customer set name='newName' where id between 2000001 and 3000001;"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertTrue(rrs.getNodes().length == 1); + Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn2")); + + } + + /** + * 测试or语句的路由 + * + * @throws Exception + */ + @Test + public void testOr() throws Exception { +// 0-200M=0 +// 200M1-400M=1 +// 400M1-600M=2 +// 600M1-800M=3 +// 800M1-1000M=4 + + SchemaConfig schema = schemaMap.get("TESTDB"); + String sql = "select * from customer where sharding_id=10000 or 1=1;"; + RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertTrue(rrs.getNodes().length == 2); + Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn1")); + Assert.assertTrue(rrs.getNodes()[1].getName().equals("dn2")); + + sql = "select * from customer where sharding_id = 10000 or sharding_id = 10010"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn1")); + Assert.assertTrue(rrs.getNodes()[1].getName().equals("dn2")); + + sql = "select * from customer where sharding_id = 10000 or user_id = 'wangwu'"; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn1")); + Assert.assertTrue(rrs.getNodes()[1].getName().equals("dn2")); + } + + /** + * 测试父子表,查询子表的语句路由到多个节点 + * @throws Exception + */ + @Test + public void testERRouteMutiNode() throws Exception { + SchemaConfig schema = schemaMap.get("TESTDB"); + String sql = "select * from orders where customer_id in(1,2000001);"; + RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertTrue(rrs.getNodes().length == 2); + Assert.assertTrue(rrs.getNodes()[0].getName().equals("dn1")); + Assert.assertTrue(rrs.getNodes()[1].getName().equals("dn2")); + } + + /** + * 测试多层or语句 + * + * @throws Exception + */ + @Test + public void testMultiLevelOr() throws Exception { + SchemaConfig schema = schemaMap.get("TESTDB"); + String sql = "select id from travelrecord " + + " where id = 1 and ( fee=3 or days=5 or (traveldate = '2015-05-04 00:00:07.375' " + + " and (user_id=2 or fee=days or fee = 0))) and name = 'zhangsan'" ; + RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + + Assert.assertTrue(rrs.getNodes().length == 1); + + sql = "select id from travelrecord " + + " where id = 1 and ( fee=3 or days=5 or (traveldate = '2015-05-04 00:00:07.375' " + + " and (user_id=2 or fee=days or fee = 0))) and name = 'zhangsan' or id = 2000001" ; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + + Assert.assertTrue(rrs.getNodes().length == 2); + + sql = "select id from travelrecord " + + " where id = 1 and ( fee=3 or days=5 or (traveldate = '2015-05-04 00:00:07.375' " + + " and (user_id=2 or fee=days or fee = 0))) and name = 'zhangsan' or id = 2000001 or id = 4000001" ; + rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + + Assert.assertTrue(rrs.getNodes().length == 3); + } + + /** + * 测试 global table 的or语句 + * + * + * @throws Exception + */ + @Test + public void testGlobalTableOr() throws Exception { + SchemaConfig schema = schemaMap.get("TESTDB"); + String sql = "select id from company where 1 = 1 and name ='company1' or name = 'company2'" ; + for(int i = 0; i < 20; i++) { + RouteResultset rrs = routeStrategy.route(new SystemConfig(), schema, ServerParse.SELECT, sql, null, null, cachePool); + Assert.assertTrue(rrs.getNodes().length == 1); + } + } + + /** + * 测试别名路由 + * + * @throws Exception + */ + public void testAlias() throws Exception { + + SchemaConfig schema = schemaMap.get("TESTDB"); + RouteResultset rrs = null; + //不支持childtable 批量插入 + //update 全局表 + String sql = "update company a set name = '' where a.id = 1;"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, + cachePool); + + Assert.assertEquals(3, rrs.getNodes().length); + + //update带别名时的路由 + sql = "update travelrecord a set name = '' where a.id = 1;"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, + cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + + //别名大小写路由 + sql = "select * from travelrecord A where a.id = 1;"; + rrs = routeStrategy.route(new SystemConfig(), schema, 1, sql, null, null, + cachePool); + Assert.assertEquals(1, rrs.getNodes().length); + } + + private String formatSql(String sql) { + MySqlStatementParser parser = new MySqlStatementParser(sql); + SQLStatement stmt = parser.parseStatement(); + return stmt.toString(); + } + + +}