Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix][mysql-service] Add insert id to mysql protocol OkPacket #856

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright 2021 DataCanvas
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.dingodb.calcite.operation;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

public class ShowLastInsertIdOperation implements QueryOperation {
Connection connection;

public ShowLastInsertIdOperation(Connection connection) {
this.connection = connection;
}

@Override
public Iterator getIterator() {
try {
List<Object[]> variableValList = new ArrayList<>();
String value = connection.getClientInfo().getProperty("last_insert_id", "0");
variableValList.add(new Object[]{value});
return variableValList.iterator();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}

@Override
public List<String> columns() {
List<String> columns = new ArrayList<>();
columns.add("last_insert_id()");
return columns;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/*
* Copyright 2021 DataCanvas
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.dingodb.calcite.rel;

import io.dingodb.calcite.DingoTable;
import io.dingodb.common.table.ColumnDefinition;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalCalc;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalExchange;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalIntersect;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalMatch;
import org.apache.calcite.rel.logical.LogicalMinus;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rel.logical.LogicalTableModify;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.logical.LogicalValues;

import java.util.ArrayDeque;
import java.util.Deque;

public class AutoIncrementShuttle implements RelShuttle {

public static AutoIncrementShuttle INSTANCE = new AutoIncrementShuttle();

protected final Deque<RelNode> stack = new ArrayDeque<>();

@Override
public RelNode visit(TableScan scan) {
return null;
}

@Override
public RelNode visit(TableFunctionScan scan) {
return null;
}

@Override
public RelNode visit(LogicalValues values) {
return null;
}

@Override
public RelNode visit(LogicalFilter filter) {
return null;
}

@Override
public RelNode visit(LogicalCalc calc) {
return null;
}

@Override
public RelNode visit(LogicalProject project) {
return null;
}

@Override
public RelNode visit(LogicalJoin join) {
return null;
}

@Override
public RelNode visit(LogicalCorrelate correlate) {
return null;
}

@Override
public RelNode visit(LogicalUnion union) {
return null;
}

@Override
public RelNode visit(LogicalIntersect intersect) {
return null;
}

@Override
public RelNode visit(LogicalMinus minus) {
return null;
}

@Override
public RelNode visit(LogicalAggregate aggregate) {
return null;
}

@Override
public RelNode visit(LogicalMatch match) {
return null;
}

@Override
public RelNode visit(LogicalSort sort) {
return null;
}

@Override
public RelNode visit(LogicalExchange exchange) {
return null;
}

@Override
public RelNode visit(LogicalTableModify modify) {
return null;
}

@Override
public RelNode visit(RelNode other) {
if (other instanceof DingoTableModify) {
DingoTableModify modify = (DingoTableModify) other;
if (modify.isInsert()) {
DingoTable table = modify.getTable().unwrap(DingoTable.class);
boolean hasAutoIncrement = false;
for (ColumnDefinition columnDefinition : table.getTableDefinition().getColumns()) {
if (columnDefinition.isAutoIncrement()) {
hasAutoIncrement = true;
}
}
if (hasAutoIncrement && other.getInputs().size() > 0) {
RelNode values = visitChildren(other);
if (values instanceof DingoValues) {
DingoValues dingoValues = (DingoValues) values;
dingoValues.setHasAutoIncrement(true);
return dingoValues;
}
}
}
return null;
} else if (other instanceof DingoValues) {
return other;
} else {
if (other.getInputs().size() > 0) {
return visitChildren(other);
} else {
return null;
}
}
}

protected RelNode visitChildren(RelNode rel) {
for (Ord<RelNode> input : Ord.zip(rel.getInputs())) {
rel = visitChild(rel, input.e);
}
return rel;
}

protected RelNode visitChild(RelNode parent, RelNode child) {
stack.push(parent);
try {
RelNode child2 = child.accept(this);
if (child2 instanceof DingoValues) {
return child2;
}
return null;
} finally {
stack.pop();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package io.dingodb.calcite.rel;

import io.dingodb.calcite.visitor.DingoRelVisitor;
import lombok.Getter;
import lombok.Setter;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
Expand All @@ -30,6 +32,10 @@
import java.util.List;

public class DingoValues extends LogicalDingoValues implements DingoRel {
@Setter
@Getter
private boolean hasAutoIncrement;

public DingoValues(
RelOptCluster cluster,
RelTraitSet traits,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.dingodb.calcite.type.converter.DefinitionMapper;
import io.dingodb.common.type.DingoType;
import lombok.Getter;
import lombok.Setter;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.AbstractRelNode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public class ScopeVariables {
immutableVariables.add("have_openssl");
immutableVariables.add("have_ssl");
immutableVariables.add("have_statement_timeout");
immutableVariables.add("last_insert_id");

characterSet.add("utf8mb4");
characterSet.add("utf8");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import io.dingodb.calcite.operation.DdlOperation;
import io.dingodb.calcite.operation.Operation;
import io.dingodb.calcite.operation.QueryOperation;
import io.dingodb.calcite.rel.AutoIncrementShuttle;
import io.dingodb.calcite.rel.DingoValues;
import io.dingodb.calcite.type.converter.DefinitionMapper;
import io.dingodb.calcite.visitor.DingoJobVisitor;
import io.dingodb.common.Location;
Expand Down Expand Up @@ -57,6 +59,7 @@
import org.checkerframework.checker.nullness.qual.Nullable;

import java.sql.DatabaseMetaData;
import java.sql.SQLClientInfoException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -269,6 +272,7 @@ public Meta.Signature parseQuery(

final RelRoot relRoot = convert(sqlNode, false);
final RelNode relNode = optimize(relRoot.rel);
extractAutoIncrement(relNode, jobIdPrefix);
Location currentLocation = MetaService.root().currentLocation();
RelDataType parasType = validator.getParameterRowType(sqlNode);
// get start_ts for jobSeqId, if transaction is not null ,transaction start_ts is jobDomainId
Expand Down Expand Up @@ -302,4 +306,26 @@ public Meta.Signature parseQuery(
job.getJobId()
);
}

/**
* Determine if it is an insert statement and if there is an autoincrement primary key in the table.
* @param relNode dingo relNode
* @param jobIdPrefix Used to distinguish between different SQL statements in the same session
*/
private void extractAutoIncrement(RelNode relNode, String jobIdPrefix) {
try {
RelNode relVal = relNode.accept(AutoIncrementShuttle.INSTANCE);
if (relVal instanceof DingoValues) {
DingoValues dingoValues = (DingoValues) relVal;
if (!dingoValues.isHasAutoIncrement()) {
return;
}
Object autoValue = dingoValues.getTuples().get(0)[0];
connection.setClientInfo("last_insert_id", autoValue.toString());
connection.setClientInfo(jobIdPrefix, autoValue.toString());
}
} catch (Exception e) {
log.error(e.getMessage(), e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.dingodb.common.mysql.constant.ErrorCode;
import io.dingodb.driver.DingoConnection;
import io.dingodb.driver.DingoPreparedStatement;
import io.dingodb.driver.DingoStatement;
import io.dingodb.driver.mysql.MysqlConnection;
import io.dingodb.driver.mysql.MysqlType;
import io.dingodb.driver.mysql.packet.ColumnPacket;
Expand All @@ -31,6 +32,7 @@
import io.dingodb.driver.mysql.packet.QueryPacket;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.avatica.Meta;
import org.apache.commons.lang3.StringUtils;

import java.io.UnsupportedEncodingException;
import java.math.BigDecimal;
Expand Down Expand Up @@ -175,8 +177,19 @@ public void executeSingleQuery(String sql, AtomicLong packetId,
} else {
// update insert delete
int count = statement.getUpdateCount();
OKPacket okPacket = MysqlPacketFactory.getInstance()
.getOkPacket(count, packetId);
DingoStatement dingoStatement = (DingoStatement) statement;
String jobIdPrefix = dingoStatement.handle.toString();
OKPacket okPacket;
if (mysqlConnection.getConnection().getClientInfo().containsKey(jobIdPrefix)) {
String lastInsertId = mysqlConnection.getConnection()
.getClientInfo().getProperty(jobIdPrefix, "0");
okPacket = MysqlPacketFactory.getInstance()
.getOkPacket(count, packetId, 0, Integer.parseInt(lastInsertId));
mysqlConnection.getConnection().getClientInfo().remove(jobIdPrefix);
} else {
okPacket = MysqlPacketFactory.getInstance()
.getOkPacket(count, packetId);
}
MysqlResponseHandler.responseOk(okPacket, mysqlConnection.channel);
}
} catch (SQLException sqlException) {
Expand Down Expand Up @@ -302,7 +315,17 @@ public void executeStatement(ExecuteStatementPacket statementPacket,
}
} else {
int affected = preparedStatement.executeUpdate();
OKPacket okPacket = mysqlPacketFactory.getOkPacket(affected, packetId);
String jobIdPrefix = preparedStatement.handle.toString();
OKPacket okPacket;
if (mysqlConnection.getConnection().getClientInfo().containsKey(jobIdPrefix)) {
String lastInsertId = mysqlConnection.getConnection()
.getClientInfo().getProperty(jobIdPrefix, "0");
okPacket = MysqlPacketFactory.getInstance()
.getOkPacket(affected, packetId, 0, Integer.parseInt(lastInsertId));
mysqlConnection.getConnection().getClientInfo().remove(jobIdPrefix);
} else {
okPacket = mysqlPacketFactory.getOkPacket(affected, packetId);
}
MysqlResponseHandler.responseOk(okPacket, mysqlConnection.channel);
}
} catch (SQLException e) {
Expand Down
Loading