Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve analysis speed (closes #8) #9

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 68 additions & 48 deletions src/main/java/com/lauriewired/malimite/database/SQLiteDBHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@

import com.lauriewired.malimite.decompile.SyntaxParser;

public class SQLiteDBHandler {
public class SQLiteDBHandler {
private String url;
private static final Logger LOGGER = Logger.getLogger(SQLiteDBHandler.class.getName());
private Connection transaction;

/*
* SQLiteDBHandler dbHandler = new SQLiteDBHandler("mydatabase.db");
Expand All @@ -39,6 +40,16 @@ public SQLiteDBHandler(String dbPath, String dbName) {
}

private void initializeDatabase() {
try {
Connection transaction = DriverManager.getConnection(url);
transaction.setAutoCommit(false);
this.transaction = transaction;
} catch (SQLException e) {
String msg = "Failed to create database transaction connection";
LOGGER.log(Level.SEVERE, msg, e);
throw new RuntimeException(msg, e);
}

String sqlClasses = "CREATE TABLE IF NOT EXISTS Classes ("
+ "ClassName TEXT,"
+ "Functions TEXT,"
Expand Down Expand Up @@ -114,6 +125,10 @@ private void initializeDatabase() {
}
}

public Connection GetTransaction() {
return this.transaction;
}

public Map<String, List<String>> getAllClassesAndFunctions() {
Map<String, List<String>> classFunctionMap = new HashMap<>();
String sql = "SELECT ClassName, Functions, ExecutableName FROM Classes";
Expand Down Expand Up @@ -239,64 +254,72 @@ public void readClasses() {
}

public void updateFunctionDecompilation(String functionName, String className, String decompiledCode, String executableName) {
try (Connection conn = DriverManager.getConnection(url)) {
updateFunctionDecompilation(conn, functionName, className, decompiledCode, executableName);
} catch (SQLException e) {
LOGGER.log(Level.SEVERE, "Error updating function decompilation", e);
e.printStackTrace();
}
}

public void updateFunctionDecompilation(Connection transaction, String functionName, String className, String decompiledCode, String executableName) {
// First, clear all existing references for this function
clearFunctionReferences(functionName, className, executableName);
clearFunctionReferences(transaction, functionName, className, executableName);

// Update the function's decompilation code
String sql = "UPDATE Functions SET DecompilationCode = ? "
+ "WHERE FunctionName = ? AND ParentClass = ? AND ExecutableName = ?";
+ "WHERE FunctionName = ? AND ParentClass = ? AND ExecutableName = ?";

try (Connection conn = DriverManager.getConnection(url);
PreparedStatement pstmt = conn.prepareStatement(sql)) {
try (PreparedStatement pstmt = transaction.prepareStatement(sql)) {
pstmt.setString(1, decompiledCode);
pstmt.setString(2, functionName);
pstmt.setString(3, className);
pstmt.setString(4, executableName);
int rowsAffected = pstmt.executeUpdate();

if (rowsAffected == 0) {
// If no rows were updated, insert a new record
sql = "INSERT INTO Functions(FunctionName, ParentClass, DecompilationCode, ExecutableName) VALUES(?, ?, ?, ?)";
try (PreparedStatement insertStmt = conn.prepareStatement(sql)) {
try (PreparedStatement insertStmt = transaction.prepareStatement(sql)) {
insertStmt.setString(1, functionName);
insertStmt.setString(2, className);
insertStmt.setString(3, decompiledCode);
insertStmt.setString(4, executableName);
rowsAffected = insertStmt.executeUpdate();
}
}

// Create a new SyntaxParser and reparse the updated function
if (decompiledCode != null && !decompiledCode.trim().isEmpty()) {
SyntaxParser parser = new SyntaxParser(this, executableName);
parser.setContext(functionName, className);
parser.collectCrossReferences(decompiledCode);
}


transaction.commit();

LOGGER.info("Database update for " + functionName + " affected " + rowsAffected + " rows");
} catch (SQLException e) {
LOGGER.log(Level.SEVERE, "Error updating function decompilation", e);
e.printStackTrace();
}
}

private void clearFunctionReferences(String functionName, String className, String executableName) {
private void clearFunctionReferences(Connection transaction, String functionName, String className, String executableName) {
String sqlFuncRefs = "DELETE FROM FunctionReferences WHERE sourceFunction = ? AND sourceClass = ? AND ExecutableName = ?";
String sqlVarRefs = "DELETE FROM LocalVariableReferences WHERE containingFunction = ? AND containingClass = ? AND ExecutableName = ?";
String sqlTypeInfo = "DELETE FROM TypeInformation WHERE functionName = ? AND className = ? AND ExecutableName = ?";

try (Connection conn = DriverManager.getConnection(url)) {
for (String sql : new String[]{sqlFuncRefs, sqlVarRefs, sqlTypeInfo}) {
try (PreparedStatement pstmt = conn.prepareStatement(sql)) {
pstmt.setString(1, functionName);
pstmt.setString(2, className);
pstmt.setString(3, executableName);
pstmt.executeUpdate();
}

for (String sql : new String[]{sqlFuncRefs, sqlVarRefs, sqlTypeInfo}) {
try (PreparedStatement pstmt = transaction.prepareStatement(sql)) {
pstmt.setString(1, functionName);
pstmt.setString(2, className);
pstmt.setString(3, executableName);
pstmt.executeUpdate();
} catch (SQLException e) {
LOGGER.log(Level.SEVERE, "Error clearing function references", e);
e.printStackTrace();
}
} catch (SQLException e) {
LOGGER.log(Level.SEVERE, "Error clearing function references", e);
e.printStackTrace();
}
}

Expand Down Expand Up @@ -394,18 +417,17 @@ public List<Map<String, String>> getResourceStrings() {
return strings;
}

public void insertFunctionReference(String sourceFunction, String sourceClass,
String targetFunction, String targetClass, int lineNumber, String executableName) {
public void insertFunctionReference(Connection transaction, String sourceFunction, String sourceClass,
String targetFunction, String targetClass, int lineNumber, String executableName) {
String sql = "INSERT INTO FunctionReferences(sourceFunction, sourceClass, "
+ "targetFunction, targetClass, lineNumber, ExecutableName) "
+ "SELECT ?, ?, ?, ?, ?, ? "
+ "WHERE NOT EXISTS (SELECT 1 FROM FunctionReferences "
+ "WHERE sourceFunction = ? AND sourceClass = ? "
+ "AND targetFunction = ? AND targetClass = ? "
+ "AND lineNumber = ? AND ExecutableName = ?)";

try (Connection conn = DriverManager.getConnection(url);
PreparedStatement pstmt = conn.prepareStatement(sql)) {
+ "targetFunction, targetClass, lineNumber, ExecutableName) "
+ "SELECT ?, ?, ?, ?, ?, ? "
+ "WHERE NOT EXISTS (SELECT 1 FROM FunctionReferences "
+ "WHERE sourceFunction = ? AND sourceClass = ? "
+ "AND targetFunction = ? AND targetClass = ? "
+ "AND lineNumber = ? AND ExecutableName = ?)";

try (PreparedStatement pstmt = transaction.prepareStatement(sql)) {
// Parameters for INSERT
pstmt.setString(1, sourceFunction);
pstmt.setString(2, sourceClass);
Expand All @@ -426,17 +448,16 @@ public void insertFunctionReference(String sourceFunction, String sourceClass,
}
}

public void insertLocalVariableReference(String variableName, String containingFunction,
String containingClass, int lineNumber, String executableName) {
public void insertLocalVariableReference(Connection transaction, String variableName, String containingFunction,
String containingClass, int lineNumber, String executableName) {
String sql = "INSERT INTO LocalVariableReferences(variableName, containingFunction, "
+ "containingClass, lineNumber, ExecutableName) "
+ "SELECT ?, ?, ?, ?, ? "
+ "WHERE NOT EXISTS (SELECT 1 FROM LocalVariableReferences "
+ "WHERE variableName = ? AND containingFunction = ? "
+ "AND containingClass = ? AND lineNumber = ? AND ExecutableName = ?)";
+ "containingClass, lineNumber, ExecutableName) "
+ "SELECT ?, ?, ?, ?, ? "
+ "WHERE NOT EXISTS (SELECT 1 FROM LocalVariableReferences "
+ "WHERE variableName = ? AND containingFunction = ? "
+ "AND containingClass = ? AND lineNumber = ? AND ExecutableName = ?)";

try (Connection conn = DriverManager.getConnection(url);
PreparedStatement pstmt = conn.prepareStatement(sql)) {
try (PreparedStatement pstmt = transaction.prepareStatement(sql)) {
// Parameters for INSERT
pstmt.setString(1, variableName);
pstmt.setString(2, containingFunction);
Expand Down Expand Up @@ -510,13 +531,12 @@ public List<Map<String, String>> getTypeInformation(String functionName, String
return types;
}

public void insertTypeInformation(String variableName, String variableType,
String functionName, String className, int lineNumber, String executableName) {
public void insertTypeInformation(Connection transaction, String variableName, String variableType,
String functionName, String className, int lineNumber, String executableName) {
String sql = "INSERT INTO TypeInformation(variableName, variableType, functionName, "
+ "className, lineNumber, ExecutableName) VALUES(?,?,?,?,?,?)";
+ "className, lineNumber, ExecutableName) VALUES(?,?,?,?,?,?)";

try (Connection conn = DriverManager.getConnection(url);
PreparedStatement pstmt = conn.prepareStatement(sql)) {
try (PreparedStatement pstmt = transaction.prepareStatement(sql)) {
pstmt.setString(1, variableName);
pstmt.setString(2, variableType);
pstmt.setString(3, functionName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ public void decompileMacho(String executableFilePath, String projectDirectoryPat
}

// Store function decompilation with the correct class name and executable name
dbHandler.updateFunctionDecompilation(functionName, className, decompiledCode, targetMacho.getMachoExecutableName());
dbHandler.updateFunctionDecompilation(dbHandler.GetTransaction(), functionName, className, decompiledCode, targetMacho.getMachoExecutableName());

// Add to class functions map
classToFunctions.computeIfAbsent(className, k -> new JSONArray())
Expand All @@ -243,7 +243,7 @@ public void decompileMacho(String executableFilePath, String projectDirectoryPat
// Store the mapping of original class name to "Libraries"
classNameMapping.put(className, "Libraries");

dbHandler.updateFunctionDecompilation(libraryFunctionName, "Libraries", targetMacho.getMachoExecutableName(), targetMacho.getMachoExecutableName());
dbHandler.updateFunctionDecompilation(dbHandler.GetTransaction(), libraryFunctionName, "Libraries", targetMacho.getMachoExecutableName(), targetMacho.getMachoExecutableName());

// Add to class functions map under "Libraries"
classToFunctions.computeIfAbsent("Libraries", k -> new JSONArray())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ public Void visitPostfixExpression(CPP14Parser.PostfixExpressionContext ctx) {

// Store the function reference with adjusted line number
dbHandler.insertFunctionReference(
dbHandler.GetTransaction(),
currentFunction,
currentClass,
calledFunction,
Expand Down Expand Up @@ -137,6 +138,7 @@ public Void visitDeclarationStatement(CPP14Parser.DeclarationStatementContext ct

// Store the type information
dbHandler.insertTypeInformation(
dbHandler.GetTransaction(),
variableName,
variableType,
currentFunction,
Expand All @@ -147,6 +149,7 @@ public Void visitDeclarationStatement(CPP14Parser.DeclarationStatementContext ct

// Store initial local variable reference
dbHandler.insertLocalVariableReference(
dbHandler.GetTransaction(),
variableName,
currentFunction,
currentClass,
Expand All @@ -173,6 +176,7 @@ public Void visitIdExpression(CPP14Parser.IdExpressionContext ctx) {

// Store class usage reference
dbHandler.insertFunctionReference(
dbHandler.GetTransaction(),
currentFunction,
currentClass,
null, // No specific function
Expand All @@ -186,6 +190,7 @@ public Void visitIdExpression(CPP14Parser.IdExpressionContext ctx) {
// Check if this identifier is in a function call context
if (!isPartOfFunctionCall(ctx)) {
dbHandler.insertLocalVariableReference(
dbHandler.GetTransaction(),
identifier,
currentFunction,
currentClass,
Expand Down