From c46fa20ef47a22f93364803a352095027a18b224 Mon Sep 17 00:00:00 2001 From: guojn1 Date: Tue, 21 Nov 2023 17:56:02 +0800 Subject: [PATCH] [fix][mysql-service] Add autoincrement id to mysql protocol OkPacket --- .../operation/ShowLastInsertIdOperation.java | 50 +++++ .../calcite/rel/AutoIncrementShuttle.java | 182 ++++++++++++++++++ .../io/dingodb/calcite/rel/DingoValues.java | 6 + .../calcite/rel/LogicalDingoValues.java | 1 + .../common/mysql/scope/ScopeVariables.java | 1 + .../io/dingodb/driver/DingoDriverParser.java | 26 +++ .../driver/mysql/command/MysqlCommands.java | 29 ++- .../mysql/command/MysqlResponseHandler.java | 21 +- .../mysql/packet/ExecuteStatementPacket.java | 3 + .../mysql/packet/MysqlPacketFactory.java | 22 ++- 10 files changed, 316 insertions(+), 25 deletions(-) create mode 100644 dingo-calcite/src/main/java/io/dingodb/calcite/operation/ShowLastInsertIdOperation.java create mode 100644 dingo-calcite/src/main/java/io/dingodb/calcite/rel/AutoIncrementShuttle.java diff --git a/dingo-calcite/src/main/java/io/dingodb/calcite/operation/ShowLastInsertIdOperation.java b/dingo-calcite/src/main/java/io/dingodb/calcite/operation/ShowLastInsertIdOperation.java new file mode 100644 index 0000000000..a4a3dfa212 --- /dev/null +++ b/dingo-calcite/src/main/java/io/dingodb/calcite/operation/ShowLastInsertIdOperation.java @@ -0,0 +1,50 @@ +/* + * Copyright 2021 DataCanvas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.dingodb.calcite.operation; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +public class ShowLastInsertIdOperation implements QueryOperation { + Connection connection; + + public ShowLastInsertIdOperation(Connection connection) { + this.connection = connection; + } + + @Override + public Iterator getIterator() { + try { + List variableValList = new ArrayList<>(); + String value = connection.getClientInfo().getProperty("last_insert_id", "0"); + variableValList.add(new Object[]{value}); + return variableValList.iterator(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + @Override + public List columns() { + List columns = new ArrayList<>(); + columns.add("last_insert_id()"); + return columns; + } +} diff --git a/dingo-calcite/src/main/java/io/dingodb/calcite/rel/AutoIncrementShuttle.java b/dingo-calcite/src/main/java/io/dingodb/calcite/rel/AutoIncrementShuttle.java new file mode 100644 index 0000000000..a3ad220bd4 --- /dev/null +++ b/dingo-calcite/src/main/java/io/dingodb/calcite/rel/AutoIncrementShuttle.java @@ -0,0 +1,182 @@ +/* + * Copyright 2021 DataCanvas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.dingodb.calcite.rel; + +import io.dingodb.calcite.DingoTable; +import io.dingodb.common.table.ColumnDefinition; +import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelShuttle; +import org.apache.calcite.rel.core.TableFunctionScan; +import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.rel.logical.LogicalCalc; +import org.apache.calcite.rel.logical.LogicalCorrelate; +import org.apache.calcite.rel.logical.LogicalExchange; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalIntersect; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rel.logical.LogicalMatch; +import org.apache.calcite.rel.logical.LogicalMinus; +import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rel.logical.LogicalSort; +import org.apache.calcite.rel.logical.LogicalTableModify; +import org.apache.calcite.rel.logical.LogicalUnion; +import org.apache.calcite.rel.logical.LogicalValues; + +import java.util.ArrayDeque; +import java.util.Deque; + +public class AutoIncrementShuttle implements RelShuttle { + + public static AutoIncrementShuttle INSTANCE = new AutoIncrementShuttle(); + + protected final Deque stack = new ArrayDeque<>(); + + @Override + public RelNode visit(TableScan scan) { + return null; + } + + @Override + public RelNode visit(TableFunctionScan scan) { + return null; + } + + @Override + public RelNode visit(LogicalValues values) { + return null; + } + + @Override + public RelNode visit(LogicalFilter filter) { + return null; + } + + @Override + public RelNode visit(LogicalCalc calc) { + return null; + } + + @Override + public RelNode visit(LogicalProject project) { + return null; + } + + @Override + public RelNode visit(LogicalJoin join) { + return null; + } + + @Override + public RelNode visit(LogicalCorrelate correlate) { + return null; + } + + @Override + public RelNode visit(LogicalUnion union) { + return null; + } + + @Override + public RelNode visit(LogicalIntersect intersect) { + return null; + } + + @Override + public RelNode visit(LogicalMinus minus) { + return null; + } + + @Override + public RelNode visit(LogicalAggregate aggregate) { + return null; + } + + @Override + public RelNode visit(LogicalMatch match) { + return null; + } + + @Override + public RelNode visit(LogicalSort sort) { + return null; + } + + @Override + public RelNode visit(LogicalExchange exchange) { + return null; + } + + @Override + public RelNode visit(LogicalTableModify modify) { + return null; + } + + @Override + public RelNode visit(RelNode other) { + if (other instanceof DingoTableModify) { + DingoTableModify modify = (DingoTableModify) other; + if (modify.isInsert()) { + DingoTable table = modify.getTable().unwrap(DingoTable.class); + boolean hasAutoIncrement = false; + for (ColumnDefinition columnDefinition : table.getTableDefinition().getColumns()) { + if (columnDefinition.isAutoIncrement()) { + hasAutoIncrement = true; + } + } + if (hasAutoIncrement && other.getInputs().size() > 0) { + RelNode values = visitChildren(other); + if (values instanceof DingoValues) { + DingoValues dingoValues = (DingoValues) values; + dingoValues.setHasAutoIncrement(true); + return dingoValues; + } + } + } + return null; + } else if (other instanceof DingoValues) { + return other; + } else { + if (other.getInputs().size() > 0) { + return visitChildren(other); + } else { + return null; + } + } + } + + protected RelNode visitChildren(RelNode rel) { + for (Ord input : Ord.zip(rel.getInputs())) { + rel = visitChild(rel, input.e); + } + return rel; + } + + protected RelNode visitChild(RelNode parent, RelNode child) { + stack.push(parent); + try { + RelNode child2 = child.accept(this); + if (child2 instanceof DingoValues) { + return child2; + } + return null; + } finally { + stack.pop(); + } + } +} diff --git a/dingo-calcite/src/main/java/io/dingodb/calcite/rel/DingoValues.java b/dingo-calcite/src/main/java/io/dingodb/calcite/rel/DingoValues.java index 34b72a5cb6..9a3188bedb 100644 --- a/dingo-calcite/src/main/java/io/dingodb/calcite/rel/DingoValues.java +++ b/dingo-calcite/src/main/java/io/dingodb/calcite/rel/DingoValues.java @@ -17,6 +17,8 @@ package io.dingodb.calcite.rel; import io.dingodb.calcite.visitor.DingoRelVisitor; +import lombok.Getter; +import lombok.Setter; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptCost; import org.apache.calcite.plan.RelOptPlanner; @@ -30,6 +32,10 @@ import java.util.List; public class DingoValues extends LogicalDingoValues implements DingoRel { + @Setter + @Getter + private boolean hasAutoIncrement; + public DingoValues( RelOptCluster cluster, RelTraitSet traits, diff --git a/dingo-calcite/src/main/java/io/dingodb/calcite/rel/LogicalDingoValues.java b/dingo-calcite/src/main/java/io/dingodb/calcite/rel/LogicalDingoValues.java index 2b73f9539c..5cbb4d2b8c 100644 --- a/dingo-calcite/src/main/java/io/dingodb/calcite/rel/LogicalDingoValues.java +++ b/dingo-calcite/src/main/java/io/dingodb/calcite/rel/LogicalDingoValues.java @@ -19,6 +19,7 @@ import io.dingodb.calcite.type.converter.DefinitionMapper; import io.dingodb.common.type.DingoType; import lombok.Getter; +import lombok.Setter; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.AbstractRelNode; diff --git a/dingo-common/src/main/java/io/dingodb/common/mysql/scope/ScopeVariables.java b/dingo-common/src/main/java/io/dingodb/common/mysql/scope/ScopeVariables.java index 61fcc21ad6..be649c12b1 100644 --- a/dingo-common/src/main/java/io/dingodb/common/mysql/scope/ScopeVariables.java +++ b/dingo-common/src/main/java/io/dingodb/common/mysql/scope/ScopeVariables.java @@ -44,6 +44,7 @@ public class ScopeVariables { immutableVariables.add("have_openssl"); immutableVariables.add("have_ssl"); immutableVariables.add("have_statement_timeout"); + immutableVariables.add("last_insert_id"); characterSet.add("utf8mb4"); characterSet.add("utf8"); diff --git a/dingo-driver/host/src/main/java/io/dingodb/driver/DingoDriverParser.java b/dingo-driver/host/src/main/java/io/dingodb/driver/DingoDriverParser.java index da427b3c36..c7210a8862 100644 --- a/dingo-driver/host/src/main/java/io/dingodb/driver/DingoDriverParser.java +++ b/dingo-driver/host/src/main/java/io/dingodb/driver/DingoDriverParser.java @@ -22,6 +22,8 @@ import io.dingodb.calcite.operation.DdlOperation; import io.dingodb.calcite.operation.Operation; import io.dingodb.calcite.operation.QueryOperation; +import io.dingodb.calcite.rel.AutoIncrementShuttle; +import io.dingodb.calcite.rel.DingoValues; import io.dingodb.calcite.type.converter.DefinitionMapper; import io.dingodb.calcite.visitor.DingoJobVisitor; import io.dingodb.common.Location; @@ -57,6 +59,7 @@ import org.checkerframework.checker.nullness.qual.Nullable; import java.sql.DatabaseMetaData; +import java.sql.SQLClientInfoException; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -269,6 +272,7 @@ public Meta.Signature parseQuery( final RelRoot relRoot = convert(sqlNode, false); final RelNode relNode = optimize(relRoot.rel); + extractAutoIncrement(relNode, jobIdPrefix); Location currentLocation = MetaService.root().currentLocation(); RelDataType parasType = validator.getParameterRowType(sqlNode); // get start_ts for jobSeqId, if transaction is not null ,transaction start_ts is jobDomainId @@ -302,4 +306,26 @@ public Meta.Signature parseQuery( job.getJobId() ); } + + /** + * Determine if it is an insert statement and if there is an autoincrement primary key in the table. + * @param relNode dingo relNode + * @param jobIdPrefix Used to distinguish between different SQL statements in the same session + */ + private void extractAutoIncrement(RelNode relNode, String jobIdPrefix) { + try { + RelNode relVal = relNode.accept(AutoIncrementShuttle.INSTANCE); + if (relVal instanceof DingoValues) { + DingoValues dingoValues = (DingoValues) relVal; + if (!dingoValues.isHasAutoIncrement()) { + return; + } + Object autoValue = dingoValues.getTuples().get(0)[0]; + connection.setClientInfo("last_insert_id", autoValue.toString()); + connection.setClientInfo(jobIdPrefix, autoValue.toString()); + } + } catch (Exception e) { + log.error(e.getMessage(), e); + } + } } diff --git a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlCommands.java b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlCommands.java index 2e2a5052b4..5407ec8166 100644 --- a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlCommands.java +++ b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlCommands.java @@ -20,6 +20,7 @@ import io.dingodb.common.mysql.constant.ErrorCode; import io.dingodb.driver.DingoConnection; import io.dingodb.driver.DingoPreparedStatement; +import io.dingodb.driver.DingoStatement; import io.dingodb.driver.mysql.MysqlConnection; import io.dingodb.driver.mysql.MysqlType; import io.dingodb.driver.mysql.packet.ColumnPacket; @@ -31,6 +32,7 @@ import io.dingodb.driver.mysql.packet.QueryPacket; import lombok.extern.slf4j.Slf4j; import org.apache.calcite.avatica.Meta; +import org.apache.commons.lang3.StringUtils; import java.io.UnsupportedEncodingException; import java.math.BigDecimal; @@ -175,8 +177,19 @@ public void executeSingleQuery(String sql, AtomicLong packetId, } else { // update insert delete int count = statement.getUpdateCount(); - OKPacket okPacket = MysqlPacketFactory.getInstance() - .getOkPacket(count, packetId); + DingoStatement dingoStatement = (DingoStatement) statement; + String jobIdPrefix = dingoStatement.handle.toString(); + OKPacket okPacket; + if (mysqlConnection.getConnection().getClientInfo().containsKey(jobIdPrefix)) { + String lastInsertId = mysqlConnection.getConnection() + .getClientInfo().getProperty(jobIdPrefix, "0"); + okPacket = MysqlPacketFactory.getInstance() + .getOkPacket(count, packetId, 0, Integer.parseInt(lastInsertId)); + mysqlConnection.getConnection().getClientInfo().remove(jobIdPrefix); + } else { + okPacket = MysqlPacketFactory.getInstance() + .getOkPacket(count, packetId); + } MysqlResponseHandler.responseOk(okPacket, mysqlConnection.channel); } } catch (SQLException sqlException) { @@ -302,7 +315,17 @@ public void executeStatement(ExecuteStatementPacket statementPacket, } } else { int affected = preparedStatement.executeUpdate(); - OKPacket okPacket = mysqlPacketFactory.getOkPacket(affected, packetId); + String jobIdPrefix = preparedStatement.handle.toString(); + OKPacket okPacket; + if (mysqlConnection.getConnection().getClientInfo().containsKey(jobIdPrefix)) { + String lastInsertId = mysqlConnection.getConnection() + .getClientInfo().getProperty(jobIdPrefix, "0"); + okPacket = MysqlPacketFactory.getInstance() + .getOkPacket(affected, packetId, 0, Integer.parseInt(lastInsertId)); + mysqlConnection.getConnection().getClientInfo().remove(jobIdPrefix); + } else { + okPacket = mysqlPacketFactory.getOkPacket(affected, packetId); + } MysqlResponseHandler.responseOk(okPacket, mysqlConnection.channel); } } catch (SQLException e) { diff --git a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlResponseHandler.java b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlResponseHandler.java index 117bae94b1..6390c525c0 100644 --- a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlResponseHandler.java +++ b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlResponseHandler.java @@ -37,24 +37,19 @@ import org.apache.commons.lang3.StringUtils; import java.lang.reflect.Array; -import java.sql.Date; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; -import java.sql.Time; -import java.sql.Timestamp; import java.util.ArrayList; import java.util.List; -import java.util.TimeZone; import java.util.concurrent.atomic.AtomicLong; -import java.util.stream.Collectors; import static io.dingodb.calcite.operation.SetOptionOperation.CONNECTION_CHARSET; import static io.dingodb.common.util.Utils.getCharacterSet; import static io.dingodb.common.util.Utils.getDateByTimezone; @Slf4j -public class MysqlResponseHandler { +public final class MysqlResponseHandler { static MysqlPacketFactory factory = MysqlPacketFactory.getInstance(); @@ -94,11 +89,8 @@ public static void responseResultSet(ResultSet resultSet, // 3. eof packet // 4. rows packet // 5. eof packet - boolean deprecateEof = false; - if ((mysqlConnection.authPacket.extendClientFlags - & ExtendedClientCapabilities.CLIENT_DEPRECATE_EOF) != 0) { - deprecateEof = true; - } + boolean deprecateEof = (mysqlConnection.authPacket.extendClientFlags + & ExtendedClientCapabilities.CLIENT_DEPRECATE_EOF) != 0; try { ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); ResultSetMetaData metaData = resultSet.getMetaData(); @@ -281,11 +273,8 @@ public static void responsePrepareExecute(ResultSet resultSet, // 3. eof packet // 4. rows packet // 5. eof packet - boolean deprecateEof = false; - if ((mysqlConnection.authPacket.extendClientFlags - & ExtendedClientCapabilities.CLIENT_DEPRECATE_EOF) != 0) { - deprecateEof = true; - } + boolean deprecateEof = (mysqlConnection.authPacket.extendClientFlags + & ExtendedClientCapabilities.CLIENT_DEPRECATE_EOF) != 0; try { ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); ResultSetMetaData metaData = resultSet.getMetaData(); diff --git a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/packet/ExecuteStatementPacket.java b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/packet/ExecuteStatementPacket.java index fb772f84d2..779b6bcafc 100644 --- a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/packet/ExecuteStatementPacket.java +++ b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/packet/ExecuteStatementPacket.java @@ -71,6 +71,9 @@ public void read(byte[] data) { nullBitmapBuilder.append(BufferUtil.getBinaryStrFromByte(message.read())); } nullBitMap = nullBitmapBuilder.toString(); + if (!message.hasRemaining()) { + return; + } newParamBoundFlag = message.read(); Integer[] types = new Integer[paramCount]; diff --git a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/packet/MysqlPacketFactory.java b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/packet/MysqlPacketFactory.java index 8fe300b6f3..d139a81c68 100644 --- a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/packet/MysqlPacketFactory.java +++ b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/packet/MysqlPacketFactory.java @@ -55,23 +55,33 @@ public static MysqlPacketFactory getInstance() { */ @NonNull public OKPacket getOkEofPacket(int affected, AtomicLong packetId, int serverStatus) { - OKPacket okPacket = newOkPacket(affected, packetId, serverStatus); + OKPacket okPacket = newOkPacket(affected, packetId, serverStatus, 0); okPacket.header = (byte) NativeConstants.TYPE_ID_EOF; return okPacket; } public OKPacket getOkPacket(int affected, AtomicLong packetId) { - return getOkPacket(affected, packetId, 0); + return getOkPacket(affected, packetId, 0, 0); } - @NonNull public OKPacket getOkPacket(int affected, AtomicLong packetId, int serverStatus) { - OKPacket okPacket = newOkPacket(affected, packetId, serverStatus); + return getOkPacket(affected, packetId, serverStatus, 0); + } + + @NonNull + public OKPacket getOkPacket(int affected, + AtomicLong packetId, + int serverStatus, + int lastInsertId) { + OKPacket okPacket = newOkPacket(affected, packetId, serverStatus, lastInsertId); okPacket.header = NativeConstants.TYPE_ID_OK; return okPacket; } - private OKPacket newOkPacket(int affected, AtomicLong packetId, int serverStatus) { + private OKPacket newOkPacket(int affected, + AtomicLong packetId, + int serverStatus, + int lastInsertId) { OKPacket okPacket = new OKPacket(); okPacket.capabilities = MysqlServer.getServerCapabilities(); okPacket.affectedRows = affected; @@ -81,7 +91,7 @@ private OKPacket newOkPacket(int affected, AtomicLong packetId, int serverStatus status |= serverStatus; } okPacket.serverStatus = status; - okPacket.insertId = 0; + okPacket.insertId = lastInsertId; return okPacket; }