Skip to content
Snippets Groups Projects
Unverified Commit d08c0e91 authored by JerryYin's avatar JerryYin Committed by GitHub
Browse files

bugfix: forbidding execute SQL which update pk value (#3129)

parent 1a150bd3
No related branches found
No related tags found
No related merge requests found
......@@ -20,6 +20,7 @@ import java.sql.Statement;
import java.util.List;
import java.util.concurrent.Callable;
import io.seata.common.exception.NotSupportYetException;
import io.seata.common.exception.ShouldNeverHappenException;
import io.seata.rm.datasource.AbstractConnectionProxy;
import io.seata.rm.datasource.ConnectionContext;
import io.seata.rm.datasource.ConnectionProxy;
......@@ -27,6 +28,7 @@ import io.seata.rm.datasource.StatementProxy;
import io.seata.rm.datasource.sql.struct.TableRecords;
import io.seata.sqlparser.SQLRecognizer;
import io.seata.sqlparser.util.JdbcConstants;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
......@@ -61,7 +63,7 @@ public abstract class AbstractDMLBaseExecutor<T, S extends Statement> extends Ba
*
* @param statementProxy the statement proxy
* @param statementCallback the statement callback
* @param sqlRecognizers the multi sql recognizer
* @param sqlRecognizers the multi sql recognizer
*/
public AbstractDMLBaseExecutor(StatementProxy<S> statementProxy, StatementCallback<T, S> statementCallback,
List<SQLRecognizer> sqlRecognizers) {
......@@ -86,8 +88,7 @@ public abstract class AbstractDMLBaseExecutor<T, S extends Statement> extends Ba
* @throws Exception the exception
*/
protected T executeAutoCommitFalse(Object[] args) throws Exception {
if (!JdbcConstants.MYSQL.equalsIgnoreCase(getDbType()) && getTableMeta().getPrimaryKeyOnlyName().size() > 1)
{
if (!JdbcConstants.MYSQL.equalsIgnoreCase(getDbType()) && getTableMeta().getPrimaryKeyOnlyName().size() > 1) {
throw new NotSupportYetException("multi pk only support mysql!");
}
TableRecords beforeImage = beforeImage();
......@@ -172,4 +173,13 @@ public abstract class AbstractDMLBaseExecutor<T, S extends Statement> extends Ba
return LOCK_RETRY_POLICY_BRANCH_ROLLBACK_ON_CONFLICT;
}
}
protected void assertContainsPKColumnName(List<String> updateColumns) {
for (String columnName : updateColumns) {
String standardColumnName = getStandardPkColumnName(columnName);
if (StringUtils.isNotEmpty(standardColumnName)) {
throw new ShouldNeverHappenException("Sorry, update pk value is not supported!");
}
}
}
}
......@@ -73,7 +73,7 @@ public abstract class BaseInsertExecutor<T, S extends Statement> extends Abstrac
@Override
protected TableRecords afterImage(TableRecords beforeImage) throws SQLException {
Map<String,List<Object>> pkValues = getPkValues();
Map<String, List<Object>> pkValues = getPkValues();
TableRecords afterImage = buildTableRecords(pkValues);
if (afterImage == null) {
throw new SQLException("Failed to build after-image for insert");
......@@ -102,8 +102,8 @@ public abstract class BaseInsertExecutor<T, S extends Statement> extends Abstrac
* get pk index
* @return the key is pk column name and the value is index of the pk column
*/
protected Map<String,Integer> getPkIndex() {
Map<String,Integer> pkIndexMap = new HashMap<>();
protected Map<String, Integer> getPkIndex() {
Map<String, Integer> pkIndexMap = new HashMap<>();
SQLInsertRecognizer recognizer = (SQLInsertRecognizer) sqlRecognizer;
List<String> insertColumns = recognizer.getInsertColumns();
if (CollectionUtils.isNotEmpty(insertColumns)) {
......@@ -111,7 +111,7 @@ public abstract class BaseInsertExecutor<T, S extends Statement> extends Abstrac
for (int paramIdx = 0; paramIdx < insertColumnsSize; paramIdx++) {
String sqlColumnName = insertColumns.get(paramIdx);
if (containPK(sqlColumnName)) {
pkIndexMap.put(getStandardColumnName(sqlColumnName),paramIdx);
pkIndexMap.put(getStandardPkColumnName(sqlColumnName), paramIdx);
}
}
return pkIndexMap;
......@@ -121,7 +121,7 @@ public abstract class BaseInsertExecutor<T, S extends Statement> extends Abstrac
for (Map.Entry<String, ColumnMeta> entry : allColumns.entrySet()) {
pkIndex++;
if (containPK(entry.getValue().getColumnName())) {
pkIndexMap.put(ColumnUtils.delEscape(entry.getValue().getColumnName(),getDbType()),pkIndex);
pkIndexMap.put(ColumnUtils.delEscape(entry.getValue().getColumnName(), getDbType()), pkIndex);
}
}
return pkIndexMap;
......@@ -132,21 +132,21 @@ public abstract class BaseInsertExecutor<T, S extends Statement> extends Abstrac
* parse primary key value from statement.
* @return
*/
protected Map<String,List<Object>> parsePkValuesFromStatement() {
protected Map<String, List<Object>> parsePkValuesFromStatement() {
// insert values including PK
SQLInsertRecognizer recognizer = (SQLInsertRecognizer) sqlRecognizer;
final Map<String, Integer> pkIndexMap = getPkIndex();
if (pkIndexMap.isEmpty()) {
throw new ShouldNeverHappenException("pkIndex is not found");
}
Map<String,List<Object>> pkValuesMap = new HashMap<>();
Map<String, List<Object>> pkValuesMap = new HashMap<>();
boolean ps = true;
if (statementProxy instanceof PreparedStatementProxy) {
PreparedStatementProxy preparedStatementProxy = (PreparedStatementProxy) statementProxy;
List<List<Object>> insertRows = recognizer.getInsertRows(pkIndexMap.values());
if (insertRows != null && !insertRows.isEmpty()) {
Map<Integer,ArrayList<Object>> parameters = preparedStatementProxy.getParameters();
Map<Integer, ArrayList<Object>> parameters = preparedStatementProxy.getParameters();
final int rowSize = insertRows.size();
int totalPlaceholderNum = -1;
for (List<Object> row : insertRows) {
......@@ -276,19 +276,16 @@ public abstract class BaseInsertExecutor<T, S extends Statement> extends Abstrac
* @param pkValues
* @return
*/
protected boolean checkPkValuesForMultiPk(Map<String,List<Object>> pkValues) {
protected boolean checkPkValuesForMultiPk(Map<String, List<Object>> pkValues) {
Set<String> pkNames = pkValues.keySet();
if (pkNames.isEmpty())
{
if (pkNames.isEmpty()) {
throw new ShouldNeverHappenException();
}
int rowSize = pkValues.get(pkNames.iterator().next()).size();
for (int i = 0;i < rowSize; i++)
{
for (int i = 0; i < rowSize; i++) {
int n = 0;
int m = 0;
for (String name : pkNames)
{
for (String name : pkNames) {
Object pkValue = pkValues.get(name).get(i);
if (pkValue instanceof Null) {
n++;
......@@ -297,24 +294,21 @@ public abstract class BaseInsertExecutor<T, S extends Statement> extends Abstrac
m++;
}
}
if (n > 1)
{
if (n > 1) {
return false;
}
if (m > 0)
{
if (m > 0) {
return false;
}
}
return true;
}
protected boolean checkPkValues(Map<String,List<Object>> pkValues, boolean ps) {
protected boolean checkPkValues(Map<String, List<Object>> pkValues, boolean ps) {
Set<String> pkNames = pkValues.keySet();
if (pkNames.size() == 1) {
return checkPkValuesForSinglePk(pkValues.get(pkNames.iterator().next()),ps);
}
else {
return checkPkValuesForSinglePk(pkValues.get(pkNames.iterator().next()), ps);
} else {
return checkPkValuesForMultiPk(pkValues);
}
}
......@@ -322,7 +316,7 @@ public abstract class BaseInsertExecutor<T, S extends Statement> extends Abstrac
/**
* check pk values for single pk
* @param pkValues
* @param ps true: is prepared statement. false: normal statement.
* @param ps true: is prepared statement. false: normal statement.
* @return true: support. false: not support.
*/
protected boolean checkPkValuesForSinglePk(List<Object> pkValues, boolean ps) {
......
......@@ -244,11 +244,11 @@ public abstract class BaseTransactionalExecutor<T, S extends Statement> implemen
/**
* get standard column name from user sql column name
* get standard pk column name from user sql column name
*
* @return
*/
protected String getStandardColumnName(String userColumnName) {
protected String getStandardPkColumnName(String userColumnName) {
String newUserColumnName = ColumnUtils.delEscape(userColumnName, getDbType());
for (String cn : getTableMeta().getPrimaryKeyOnlyName()) {
if (cn.toUpperCase().equals(newUserColumnName.toUpperCase())) {
......
......@@ -80,6 +80,7 @@ public class MultiUpdateExecutor<T, S extends Statement> extends AbstractDMLBase
sqlRecognizer = recognizer;
SQLUpdateRecognizer sqlUpdateRecognizer = (SQLUpdateRecognizer) recognizer;
List<String> updateColumns = sqlUpdateRecognizer.getUpdateColumns();
assertContainsPKColumnName(updateColumns);
updateColumnsSet.addAll(updateColumns);
if (noWhereCondition) {
continue;
......@@ -132,7 +133,7 @@ public class MultiUpdateExecutor<T, S extends Statement> extends AbstractDMLBase
String selectSQL = buildAfterImageSQL(tmeta, beforeImage);
ResultSet rs = null;
try (PreparedStatement pst = statementProxy.getConnection().prepareStatement(selectSQL);) {
SqlGenerateUtils.setParamForPk(beforeImage.pkRows(),getTableMeta().getPrimaryKeyOnlyName(),pst);
SqlGenerateUtils.setParamForPk(beforeImage.pkRows(), getTableMeta().getPrimaryKeyOnlyName(), pst);
rs = pst.executeQuery();
return TableRecords.buildRecords(tmeta, rs);
} finally {
......@@ -149,7 +150,7 @@ public class MultiUpdateExecutor<T, S extends Statement> extends AbstractDMLBase
updateColumnsSet.addAll(sqlUpdateRecognizer.getUpdateColumns());
}
StringBuilder prefix = new StringBuilder("SELECT ");
String suffix = " FROM " + getFromTableInSQL() + " WHERE " + SqlGenerateUtils.buildWhereConditionByPKs(tableMeta.getPrimaryKeyOnlyName(),beforeImage.pkRows().size(),getDbType());
String suffix = " FROM " + getFromTableInSQL() + " WHERE " + SqlGenerateUtils.buildWhereConditionByPKs(tableMeta.getPrimaryKeyOnlyName(), beforeImage.pkRows().size(), getDbType());
StringJoiner selectSQLJoiner = new StringJoiner(", ", prefix.toString(), suffix);
if (ONLY_CARE_UPDATE_COLUMNS) {
if (!containsPK(new ArrayList<>(updateColumnsSet))) {
......
......@@ -73,6 +73,8 @@ public class UpdateExecutor<T, S extends Statement> extends AbstractDMLBaseExecu
private String buildBeforeImageSQL(TableMeta tableMeta, ArrayList<List<Object>> paramAppenderList) {
SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer) sqlRecognizer;
List<String> updateColumns = recognizer.getUpdateColumns();
assertContainsPKColumnName(updateColumns);
StringBuilder prefix = new StringBuilder("SELECT ");
StringBuilder suffix = new StringBuilder(" FROM ").append(getFromTableInSQL());
String whereCondition = buildWhereCondition(recognizer, paramAppenderList);
......@@ -90,7 +92,6 @@ public class UpdateExecutor<T, S extends Statement> extends AbstractDMLBaseExecu
suffix.append(" FOR UPDATE");
StringJoiner selectSQLJoin = new StringJoiner(", ", prefix.toString(), suffix.toString());
if (ONLY_CARE_UPDATE_COLUMNS) {
List<String> updateColumns = recognizer.getUpdateColumns();
if (!containsPK(updateColumns)) {
selectSQLJoin.add(getColumnNamesInSQL(tableMeta.getEscapePkNameList(getDbType())));
}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment