Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New matching and event reporting module #84

Merged
merged 8 commits into from
Feb 9, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 83 additions & 41 deletions reccmp/isledecomp/compare/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@
from reccmp.isledecomp.parser import DecompCodebase
from reccmp.isledecomp.dir import walk_source_dir
from reccmp.isledecomp.types import EntityType
from reccmp.isledecomp.compare.event import create_logging_wrapper
from reccmp.isledecomp.compare.asm import ParseAsm
from reccmp.isledecomp.compare.asm.replacement import create_name_lookup
from reccmp.isledecomp.compare.asm.fixes import assert_fixup, find_effective_match
from reccmp.isledecomp.analysis import find_float_consts
from .match_msvc import (
match_symbols,
match_functions,
match_vtables,
match_static_variables,
match_variables,
match_strings,
)
from .db import EntityDb, ReccmpEntity, ReccmpMatch
from .diff import combined_diff, CombinedDiffOutput
from .lines import LinesDb
Expand Down Expand Up @@ -142,7 +151,7 @@ def _load_cvdump(self):

# Build the list of entries to insert to the DB.
# In the rare case we have duplicate symbols for an address, ignore them.
dataset = {}
seen_addrs = set()

batch = self._db.batch()

Expand All @@ -162,9 +171,11 @@ def _load_cvdump(self):
addr = self.recomp_bin.get_abs_addr(sym.section, sym.offset)
sym.addr = addr

if addr in dataset:
if addr in seen_addrs:
continue

seen_addrs.add(addr)

# If this symbol is the final one in its section, we were not able to
# estimate its size because we didn't have the total size of that section.
# We can get this estimate now and assume that the final symbol occupies
Expand Down Expand Up @@ -262,51 +273,82 @@ def orig_bin_checker(addr: int) -> bool:
# If we have two functions that share the same name, and one is
# a lineref, we can match the nameref correctly because the lineref
# was already removed from consideration.
for fun in codebase.iter_line_functions():
assert fun.filename is not None
recomp_addr = self._lines_db.search_line(
fun.filename, fun.line_number, fun.end_line
)
if recomp_addr is not None:
self._db.set_function_pair(fun.offset, recomp_addr)
if fun.should_skip():
self._db.mark_stub(fun.offset)

for fun in codebase.iter_name_functions():
self._db.match_function(fun.offset, fun.name)
if fun.should_skip():
self._db.mark_stub(fun.offset)

for var in codebase.iter_variables():
if var.is_static and var.parent_function is not None:
self._db.match_static_variable(
var.offset, var.name, var.parent_function
with self._db.batch() as batch:
for fun in codebase.iter_line_functions():
assert fun.filename is not None
recomp_addr = self._lines_db.search_line(
fun.filename, fun.line_number, fun.end_line
)
if recomp_addr is not None:
batch.match(fun.offset, recomp_addr)
batch.set_recomp(
recomp_addr, type=EntityType.FUNCTION, stub=fun.should_skip()
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to unify these two functions? Not sure how much they are used independent of each other (maybe at least one of them can be removed)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would probably be better to have set_orig here, as with all the other imports from annotations that follow. I don't remember why I put set_recomp here. The final result to the database will be the same unless this match fails.


with self._db.batch() as batch:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not immediately obvious to me why you use separate batches for some, but not all types of insertions. If there is a good reason, I'd add a comment explaining why. If not, why don't we use one big batch?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I had these in different batches at one point and using insert_orig to match the previous code better. I'll switch this around tomorrow.

The reason to use separate batches is so that staging data with insert_orig would succeed once (in the first batch) and then not change the data in subsequent batches. If you keep calling insert_orig on the same address in the same batch, we modify the pending changes. This is by design so you can add attributes in stages (or only if certain conditions are met) as we do here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could do it all in one batch if the annotations read from the DecompCodebase were guaranteed to have a unique orig address. I don't remember if that's true or not. We detect if you repeat the same addr in two different annotations (in the linter) but I don't think we remove the dupes here.

Copy link
Collaborator

@jonschz jonschz Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a valid reason to do multiple batches. I'd just add that to the code as a comment since the pattern was not obvious to me.

for fun in codebase.iter_name_functions():
batch.set_orig(
fun.offset, type=EntityType.FUNCTION, stub=fun.should_skip()
)
else:
self._db.match_variable(var.offset, var.name)

for tbl in codebase.iter_vtables():
self._db.match_vtable(tbl.offset, tbl.name, tbl.base_class)
if fun.name.startswith("?"):
batch.set_orig(fun.offset, symbol=fun.name)
else:
batch.set_orig(fun.offset, name=fun.name)

for var in codebase.iter_variables():
batch.set_orig(var.offset, name=var.name, type=EntityType.DATA)
if var.is_static and var.parent_function is not None:
batch.set_orig(
var.offset, static_var=True, parent_function=var.parent_function
)

for string in codebase.iter_strings():
# Not that we don't trust you, but we're checking the string
# annotation to make sure it is accurate.
try:
# TODO: would presumably fail for wchar_t strings
orig = self.orig_bin.read_string(string.offset).decode("latin1")
string_correct = string.name == orig
except UnicodeDecodeError:
string_correct = False

if not string_correct:
logger.error(
"Data at 0x%x does not match string %s",
for tbl in codebase.iter_vtables():
batch.set_orig(
tbl.offset,
name=tbl.name,
base_class=tbl.base_class,
type=EntityType.VTABLE,
)

# For now, just redirect match alerts to the logger.
report = create_logging_wrapper(logger)

# Now match
match_symbols(self._db, report)
match_functions(self._db, report)
match_vtables(self._db, report)
match_static_variables(self._db, report)
match_variables(self._db, report)
disinvite marked this conversation as resolved.
Show resolved Hide resolved

with self._db.batch() as batch:
for string in codebase.iter_strings():
# Not that we don't trust you, but we're checking the string
# annotation to make sure it is accurate.
try:
# TODO: would presumably fail for wchar_t strings
orig = self.orig_bin.read_string(string.offset).decode("latin1")
string_correct = string.name == orig
except UnicodeDecodeError:
string_correct = False

if not string_correct:
logger.error(
"Data at 0x%x does not match string %s",
string.offset,
repr(string.name),
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be a report? Maybe it also makes sense to go over all remaining log calls and see which of those should be reports.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I just added report calls to the alerts from load_markers. We can do the ones from other functions in upcoming PRs.

continue

batch.set_orig(
string.offset,
repr(string.name),
name=string.name,
type=EntityType.STRING,
size=len(string.name),
)
continue
# self._db.match_string(string.offset, string.name)

self._db.match_string(string.offset, string.name)
match_strings(self._db, report)

def _match_array_elements(self):
"""
Expand Down
120 changes: 14 additions & 106 deletions reccmp/isledecomp/compare/db.py
disinvite marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@
matched int as (orig_addr is not null and recomp_addr is not null),
kvstore text default '{}'
);

CREATE VIEW orig_unmatched (orig_addr, kvstore) AS
SELECT orig_addr, kvstore FROM entities
WHERE orig_addr is not null and recomp_addr is null
ORDER by orig_addr;

CREATE VIEW recomp_unmatched (recomp_addr, kvstore) AS
SELECT recomp_addr, kvstore FROM entities
WHERE recomp_addr is not null and orig_addr is null
ORDER by recomp_addr;
Comment on lines +20 to +28
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume you had good reasons to directly add SQL to match_msvc.py instead of writing abstractions for all the SQL queries. If you don't intend to add abstractions, how about adding a comment here that these views are used in other files? At first glance, db.py may look like it is self-contained (and I think it was before).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those views are just for convenience so we can exclude this boilerplate from all the queries:

where defined_addr is not null
and other_addr is null
order by defined_addr

We still have some abstraction: queries are used only in db.py and each of the matching functions. The DB was self contained until now, but only because the matching functions were there and not in their own module. I think it makes sense to split those up so we can match new entity types and things specific to other compilers and not have the DB class get unmanageable.

Using queries directly in the match functions was just a way to get this new idea moving. I don't think single use-case functions like get_unmatched_orig_function_names are obviously better than a SQL query. Less-specific utility functions are slower because you're doing less work at the sqlite level. The tradeoff is that more queries flying around locks us into this schema, and if we ever move off of the JSON string then it's just more stuff to rewrite. Something to revisit later.

"""


Expand Down Expand Up @@ -238,6 +248,10 @@ def sql(self) -> sqlite3.Connection:
def batch(self) -> EntityBatch:
return EntityBatch(self)

def count(self) -> int:
(count,) = self._sql.execute("SELECT count(1) from entities").fetchone()
return count

def set_orig_symbol(self, addr: int, **kwargs):
self.bulk_orig_insert(iter([(addr, kwargs)]))

Expand Down Expand Up @@ -621,112 +635,6 @@ def get_next_orig_addr(self, addr: int) -> int | None:

return result[0] if result is not None else None

def match_function(self, addr: int, name: str) -> bool:
did_match = self._match_on(EntityType.FUNCTION, addr, name)
if not did_match:
logger.error(
"Failed to find function symbol with annotation 0x%x and name '%s'",
addr,
name,
)

return did_match

def match_vtable(
self, addr: int, class_name: str, base_class: str | None = None
) -> bool:
"""Match the vtable for the given class name. If a base class is provided,
we will match the multiple inheritance vtable instead.

As with other name-based searches, set the given address on the first unmatched result.

Our search here depends on having already demangled the vtable symbol before
loading the data. For example: we want to search for "Pizza::`vftable'"
so we extract the class name from its symbol "??_7Pizza@@6B@".

For multiple inheritance, the vtable name references the base class like this:

- X::`vftable'{for `Y'}

The vtable for the derived class will take one of these forms:

- X::`vftable'{for `X'}
- X::`vftable'

We assume only one of the above will appear for a given class."""
# Most classes will not use multiple inheritance, so try the regular vtable
# first, unless a base class is provided.
if base_class is None or base_class == class_name:
bare_vftable = f"{class_name}::`vftable'"

for obj in self.search_name(bare_vftable, EntityType.VTABLE):
if obj.orig_addr is None and obj.recomp_addr is not None:
return self.set_pair(addr, obj.recomp_addr, EntityType.VTABLE)

# If we didn't find a match above, search for the multiple inheritance vtable.
for_name = base_class if base_class is not None else class_name
for_vftable = f"{class_name}::`vftable'{{for `{for_name}'}}"

for obj in self.search_name(for_vftable, EntityType.VTABLE):
if obj.orig_addr is None and obj.recomp_addr is not None:
return self.set_pair(addr, obj.recomp_addr, EntityType.VTABLE)

logger.error(
"Failed to find vtable for class with annotation 0x%x and name '%s'",
addr,
class_name,
)
return False

def match_static_variable(
self, addr: int, variable_name: str, function_addr: int
) -> bool:
"""Matching a static function variable by combining the variable name
with the decorated (mangled) name of its parent function."""

result = self._sql.execute(
"SELECT json_extract(kvstore, '$.name'), json_extract(kvstore, '$.symbol') FROM entities WHERE orig_addr = ?",
(function_addr,),
).fetchone()

if result is None:
logger.error("No function for static variable: %s", variable_name)
return False

# Get the friendly name for the "failed to match" error message
(function_name, function_symbol) = result

# If the static variable has a symbol, it will contain the parent function's symbol.
# e.g. Static variable "g_startupDelay" from function "IsleApp::Tick"
# The function symbol is: "?Tick@IsleApp@@QAEXH@Z"
# The variable symbol is: "?g_startupDelay@?1??Tick@IsleApp@@QAEXH@Z@4HA"
for (recomp_addr,) in self._sql.execute(
"""SELECT recomp_addr FROM entities
WHERE orig_addr IS NULL
AND (json_extract(kvstore, '$.type') = ? OR json_extract(kvstore, '$.type') IS NULL)
AND json_extract(kvstore, '$.symbol') LIKE '%' || ? || '%' || ? || '%'""",
(EntityType.DATA, variable_name, function_symbol),
):
return self.set_pair(addr, recomp_addr, EntityType.DATA)

logger.error(
"Failed to match static variable %s from function %s annotated with 0x%x",
variable_name,
function_name,
addr,
)

return False

def match_variable(self, addr: int, name: str) -> bool:
did_match = self._match_on(EntityType.DATA, addr, name) or self._match_on(
EntityType.POINTER, addr, name
)
if not did_match:
logger.error("Failed to find variable annotated with 0x%x: %s", addr, name)

return did_match

def match_string(self, addr: int, value: str) -> bool:
did_match = self._match_on(EntityType.STRING, addr, value)
if not did_match:
Expand Down
48 changes: 48 additions & 0 deletions reccmp/isledecomp/compare/event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import enum
import logging
from typing import Protocol


class LoggingSeverity(enum.IntEnum):
"""To improve type checking. There isn't an enum to import from the logging module."""

DEBUG = logging.DEBUG
INFO = logging.INFO
WARNING = logging.WARNING
ERROR = logging.ERROR


class ReccmpEvent(enum.Enum):
NO_MATCH = enum.auto()

# Symbol (or designated unique attribute) was found not to be unique
NON_UNIQUE_SYMBOL = enum.auto()

# Match by name/type not unique
AMBIGUOUS_MATCH = enum.auto()


def event_to_severity(event: ReccmpEvent) -> LoggingSeverity:
return {
ReccmpEvent.NO_MATCH: LoggingSeverity.ERROR,
ReccmpEvent.NON_UNIQUE_SYMBOL: LoggingSeverity.WARNING,
ReccmpEvent.AMBIGUOUS_MATCH: LoggingSeverity.WARNING,
}.get(event, LoggingSeverity.INFO)


class ReccmpReportProtocol(Protocol):
disinvite marked this conversation as resolved.
Show resolved Hide resolved
def __call__(self, event: ReccmpEvent, orig_addr: int, /, msg: str = ""):
...


def reccmp_report_nop(*_, **__):
"""Reporting no-op function"""


def create_logging_wrapper(logger: logging.Logger) -> ReccmpReportProtocol:
"""Return a function to use when you just want to redirect events to the given logger"""

def wrap(event: ReccmpEvent, _: int, msg: str = ""):
logger.log(event_to_severity(event), msg)

return wrap
Loading