check_modular_conversion.py
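# Overview of the flow implemented below:
#   1. get_models_in_diff() collects the model folders touched since the merge-base with main.
#   2. find_priority_list() orders the modular files and returns their dependency mapping.
#   3. Every modular_xxx.py that is not guaranteed_no_diff() is passed to compare_files(), which
#      regenerates its files with convert_modular_file() and diffs them against the checked-in
#      versions, optionally overwriting them when --fix_and_overwrite is given.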
import argparse
import difflib
import glob
import logging
import subprocess
from io import StringIO

from create_dependency_mapping import find_priority_list

# Console for rich printing
from modular_model_converter import convert_modular_file
from rich.console import Console
from rich.syntax import Syntax


logging.basicConfig()
logging.getLogger().setLevel(logging.ERROR)

console = Console()
def process_file(modular_file_path, generated_modeling_content, file_type="modeling_", fix_and_overwrite=False):
    file_name_prefix = file_type.split("*")[0]
    file_name_suffix = file_type.split("*")[-1] if "*" in file_type else ""
    file_path = modular_file_path.replace("modular_", f"{file_name_prefix}_").replace(".py", f"{file_name_suffix}.py")
    # Read the actual modeling file
    with open(file_path, "r") as modeling_file:
        content = modeling_file.read()
    output_buffer = StringIO(generated_modeling_content[file_type][0])
    output_buffer.seek(0)
    output_content = output_buffer.read()

    diff = difflib.unified_diff(
        output_content.splitlines(),
        content.splitlines(),
        fromfile=f"{file_path}_generated",
        tofile=f"{file_path}",
        lineterm="",
    )
    diff_list = list(diff)

    # Check for differences
    if diff_list:
        if fix_and_overwrite:
            with open(file_path, "w") as modeling_file:
                modeling_file.write(generated_modeling_content[file_type][0])
            console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
        else:
            console.print(f"\n[bold red]Differences found between the generated code and {file_path}:[/bold red]\n")
            diff_text = "\n".join(diff_list)
            syntax = Syntax(diff_text, "diff", theme="ansi_dark", line_numbers=True)
            console.print(syntax)
        return 1
    else:
        console.print(f"[bold green]No differences found for {file_path}.[/bold green]")
        return 0
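# Illustration of the file path mapping in process_file (the file_type keys come from
# convert_modular_file; the keys below are examples only):
#   modular_file_path = "src/transformers/models/llama/modular_llama.py"
#     file_type "modeling"               -> ".../modeling_llama.py"
#     file_type "image_processing*_fast" -> ".../image_processing_llama_fast.py"
# process_file returns 1 when the generated content and the file on disk differ
# (or the file was just overwritten), and 0 when they match.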
def compare_files(modular_file_path, fix_and_overwrite=False):
    # Generate the expected modeling content
    generated_modeling_content = convert_modular_file(modular_file_path)
    diff = 0
    for file_type in generated_modeling_content.keys():
        diff += process_file(modular_file_path, generated_modeling_content, file_type, fix_and_overwrite)
    return diff
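# compare_files sums process_file over every file type produced for a modular file, so its
# return value is the number of generated files that do not match what is on disk (a modular
# file that yields, say, both a configuration and a modeling file can return 0, 1 or 2).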
def get_models_in_diff():
    """
    Finds all models that have been modified in the diff.

    Returns:
        A set containing the names of the models that have been modified (e.g. {'llama', 'whisper'}).
    """
    fork_point_sha = subprocess.check_output("git merge-base main HEAD".split()).decode("utf-8")
    modified_files = (
        subprocess.check_output(f"git diff --diff-filter=d --name-only {fork_point_sha}".split())
        .decode("utf-8")
        .split()
    )

    # Matches both modeling files and tests
    relevant_modified_files = [x for x in modified_files if "/models/" in x and x.endswith(".py")]
    model_names = set()
    for file_path in relevant_modified_files:
        model_name = file_path.split("/")[-2]
        model_names.add(model_name)
    return model_names
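# Example of the model name extraction above: both
# "src/transformers/models/llama/modeling_llama.py" and "tests/models/llama/test_modeling_llama.py"
# contain "/models/", end in ".py", and yield "llama" via file_path.split("/")[-2].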
def guaranteed_no_diff(modular_file_path, dependencies, models_in_diff):
    """
    Returns whether it is guaranteed to have no differences between the modular file and the modeling file.

    Model is in the diff -> not guaranteed to have no differences
    Dependency is in the diff -> not guaranteed to have no differences
    Otherwise -> guaranteed to have no differences

    Args:
        modular_file_path: The path to the modular file.
        dependencies: A dictionary containing the dependencies of each modular file.
        models_in_diff: A set containing the names of the models that have been modified.

    Returns:
        A boolean indicating whether the model (code and tests) is guaranteed to have no differences.
    """
    model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
    if model_name in models_in_diff:
        return False

    for dep in dependencies[modular_file_path]:
        # two possible patterns: `transformers.models.model_name.(...)` or `model_name.(...)`
        dependency_model_name = dep.split(".")[-2]
        if dependency_model_name in models_in_diff:
            return False
    return True
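# Example of the dependency parsing above: both "transformers.models.llama.modeling_llama" and
# "llama.modeling_llama" yield "llama" via dep.split(".")[-2], so a modification to the llama
# model disables the "guaranteed no diff" shortcut for every modular file that depends on it.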
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
    parser.add_argument(
        "--files", default=["all"], type=str, nargs="+", help="List of modular_xxx.py files to compare."
    )
    parser.add_argument(
        "--fix_and_overwrite", action="store_true", help="Overwrite the modeling_xxx.py file if differences are found."
    )
    parser.add_argument(
        "--num_workers",
        default=1,
        type=int,
        help="The number of workers to run. No effect if `fix_and_overwrite` is specified.",
    )
    args = parser.parse_args()
    if args.files == ["all"]:
        args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)

    # Assuming there is a topological sort on the dependency mapping: if the file being checked and its dependencies
    # are not in the diff, then it is guaranteed to have no differences. If no models are in the diff, then this
    # script will do nothing.
    models_in_diff = get_models_in_diff()
    if not models_in_diff:
        console.print("[bold green]No model files or model tests in the diff, skipping modular checks[/bold green]")
        exit(0)

    skipped_models = set()
    non_matching_files = 0
    ordered_files, dependencies = find_priority_list(args.files)
    if args.fix_and_overwrite or args.num_workers == 1:
        for modular_file_path in ordered_files:
            is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
            if is_guaranteed_no_diff:
                model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
                skipped_models.add(model_name)
                continue
            non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
            models_in_diff = get_models_in_diff()  # When overwriting, the diff changes
    else:
        new_ordered_files = []
        for modular_file_path in ordered_files:
            is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
            if is_guaranteed_no_diff:
                model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
                skipped_models.add(model_name)
            else:
                new_ordered_files.append(modular_file_path)

        import multiprocessing

        with multiprocessing.Pool(args.num_workers) as p:
            outputs = p.map(compare_files, new_ordered_files)
        for output in outputs:
            non_matching_files += output

    if non_matching_files and not args.fix_and_overwrite:
        raise ValueError("Some modular files and their generated modeling code did not match.")

    if skipped_models:
        console.print(
            f"[bold green]Skipped {len(skipped_models)} models and their dependencies that are not in the diff: "
            f"{', '.join(skipped_models)}[/bold green]"
        )