diff --git a/src/main/java/io/mycat/route/util/RouterUtil.java b/src/main/java/io/mycat/route/util/RouterUtil.java index 52360308e..2789bcc9e 100644 --- a/src/main/java/io/mycat/route/util/RouterUtil.java +++ b/src/main/java/io/mycat/route/util/RouterUtil.java @@ -52,7 +52,6 @@ public class RouterUtil { private static final Logger LOGGER = LoggerFactory.getLogger(RouterUtil.class); - /** * 移除执行语句中的数据库名 * @@ -62,7 +61,6 @@ public class RouterUtil { * * @author mycat */ - public static String removeSchema(String stmt, String schema) { final String upStmt = stmt.toUpperCase(); final String upSchema = schema.toUpperCase() + "."; @@ -102,9 +100,15 @@ public static String removeSchema(String stmt, String schema) { private static int countChar(String sql,int end) { int count=0; + boolean skipChar = false; for (int i = 0; i < end; i++) { - if(sql.charAt(i)=='\'') { + if(sql.charAt(i)=='\'' && !skipChar) { count++; + skipChar = false; + }else if( sql.charAt(i)=='\\'){ + skipChar = true; + }else{ + skipChar = false; } } return count; diff --git a/src/test/java/io/mycat/route/util/RouterUtilTest.java b/src/test/java/io/mycat/route/util/RouterUtilTest.java index 949ceecd9..f33ece835 100644 --- a/src/test/java/io/mycat/route/util/RouterUtilTest.java +++ b/src/test/java/io/mycat/route/util/RouterUtilTest.java @@ -14,6 +14,9 @@ * @date 2016/7/19 */ public class RouterUtilTest { + + + @Test public void testBatchInsert() { String sql = "insert into hotnews(title,name) values('test1',\"name\"),('(test)',\"(test)\"),('\\\"',\"\\'\"),(\")\",\"\\\"\\')\");"; @@ -45,10 +48,34 @@ public void testRemoveSchemaSelect() { @Test public void testRemoveSchemaSelect2() { - String sql = "select id as 'aa' from testx.test where name='abcd testx.aa' and id=1 and testx=123"; + String sql = "select id as 'aa' from testx.test where name='abcd testx.aa' and id=1 and testx=123"; String afterAql= RouterUtil.removeSchema(sql,"testx"); Assert.assertNotSame(sql.indexOf("testx."),afterAql.indexOf("testx.")); } + + @Test + public void testRemoveSchema2(){ + String sql = "update testx.test set name='abcd \\' testx.aa' where id=1"; + String sqltrue = "update test set name='abcd \\' testx.aa' where id=1"; + String sqlnew = RouterUtil.removeSchema(sql, "testx"); + Assert.assertEquals("处理错误:", sqltrue, sqlnew); + } + + @Test + public void testRemoveSchema3(){ + String sql = "update testx.test set testx.name='abcd testx.aa' where testx.id=1"; + String sqltrue = "update test set name='abcd testx.aa' where id=1"; + String sqlnew = RouterUtil.removeSchema(sql, "testx"); + Assert.assertEquals("处理错误:", sqltrue, sqlnew); + } + + @Test + public void testRemoveSchema4(){ + String sql = "update testx.test set testx.name='abcd testx.aa' and testx.name2='abcd testx.aa' where testx.id=1"; + String sqltrue = "update test set name='abcd testx.aa' and name2='abcd testx.aa' where id=1"; + String sqlnew = RouterUtil.removeSchema(sql, "testx"); + Assert.assertEquals("处理错误:", sqltrue, sqlnew); + } }