import argparse
import collections
import json
import os
import re
from contextlib import contextmanager
from pathlib import Path

from git import Repo


PATH_TO_TRANFORMERS = "."


all_model_test_files = [str(x) for x in Path("tests/models/").glob("**/test_modeling_*.py")]

all_pipeline_test_files = [str(x) for x in Path("tests/pipelines/").glob("**/test_pipelines_*.py")]


@contextmanager
def checkout_commit(repo, commit_id):
    """
    Context manager that checks out a commit in the repo and restores the current head once done.
    """
    current_head = repo.head.commit if repo.head.is_detached else repo.head.ref

    try:
        repo.git.checkout(commit_id)
        yield

    finally:
        repo.git.checkout(current_head)
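
# Illustrative usage (a sketch, not executed at import time; the commit id is a placeholder):
#
#     repo = Repo(".")
#     with checkout_commit(repo, "HEAD~1"):
#         ...  # read files as they were one commit ago
#     # on exit, the previous branch (or commit) is checked out again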


def clean_code(content):
    """
    Remove docstrings, empty lines, and comments from `content`.
    """
    # Drop everything between triple quotes (docstrings and multi-line strings).
    splits = content.split('"""')
    content = "".join(splits[::2])
    splits = content.split("'''")
    content = "".join(splits[::2])

    # Remove comments and empty lines.
    lines_to_keep = []
    for line in content.split("\n"):
        # Remove everything after a `#` (comments).
        line = re.sub("#.*$", "", line)
        if len(line) == 0 or line.isspace():
            continue
        lines_to_keep.append(line)
    return "\n".join(lines_to_keep)
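
# Illustrative example (a sketch, not executed at import time; the result is shown as a comment):
#
#     clean_code('x = 1  # the answer\n\n"""Module docstring."""\ny = 2\n')
#     # -> 'x = 1  \ny = 2'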


def get_all_tests():
    """
    Return a list of paths to all test folders and files under `tests`. All paths are rooted at `tests`:

    - folders under `tests`: `tokenization`, `pipelines`, etc. The folder `models` is excluded (its subfolders are
      listed instead).
    - folders under `tests/models`: `bert`, `gpt2`, etc.
    - test files under `tests`: `test_modeling_common.py`, `test_tokenization_common.py`, etc.
    """
    test_root_dir = os.path.join(PATH_TO_TRANFORMERS, "tests")

    # Test folders/files directly under the `tests` folder.
    tests = os.listdir(test_root_dir)
    tests = sorted(filter(lambda x: os.path.isdir(x) or x.startswith("tests/test_"), [f"tests/{x}" for x in tests]))

    # Model test folders.
    model_test_folders = os.listdir(os.path.join(test_root_dir, "models"))
    model_test_folders = sorted(filter(os.path.isdir, [f"tests/models/{x}" for x in model_test_folders]))

    tests.remove("tests/models")
    tests = model_test_folders + tests

    return tests
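
# Illustrative return value (hypothetical; the exact content depends on the local checkout):
#
#     ["tests/models/albert", "tests/models/bert", ..., "tests/pipelines",
#      ..., "tests/test_configuration_common.py", ...]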


def diff_is_docstring_only(repo, branching_point, filename):
    """
    Check whether the diff in `filename` between the branching point and the current head only concerns docstrings
    or comments.
    """
    with checkout_commit(repo, branching_point):
        with open(filename, "r", encoding="utf-8") as f:
            old_content = f.read()

    with open(filename, "r", encoding="utf-8") as f:
        new_content = f.read()

    old_content_clean = clean_code(old_content)
    new_content_clean = clean_code(new_content)

    return old_content_clean == new_content_clean


def get_modified_python_files(diff_with_last_commit=False):
    """
    Return a list of Python files that have been modified between:

    - the current head and the main branch if `diff_with_last_commit=False` (default)
    - the current head and its parent commit otherwise.
    """
    repo = Repo(PATH_TO_TRANFORMERS)

    if not diff_with_last_commit:
        print(f"main is at {repo.refs.main.commit}")
        print(f"Current head is at {repo.head.commit}")

        branching_commits = repo.merge_base(repo.refs.main, repo.head)
        for commit in branching_commits:
            print(f"Branching commit: {commit}")
        return get_diff(repo, repo.head.commit, branching_commits)
    else:
        print(f"Current head is at {repo.head.commit}")
        parent_commits = repo.head.commit.parents
        for commit in parent_commits:
            print(f"Parent commit: {commit}")
        return get_diff(repo, repo.head.commit, parent_commits)
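
# Illustrative call (a sketch; the actual output depends on the local diff):
#
#     get_modified_python_files()
#     # -> ["src/transformers/modeling_utils.py", "tests/test_modeling_common.py"]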


def get_diff(repo, base_commit, commits):
    """
    Gets the diff between one or several commits and the head of the repository.
    """
    print("\n### DIFF ###\n")
    code_diff = []
    for commit in commits:
        for diff_obj in commit.diff(base_commit):
            # Added Python files are always picked up.
            if diff_obj.change_type == "A" and diff_obj.b_path.endswith(".py"):
                code_diff.append(diff_obj.b_path)
            # Deleted Python files are picked up too, as their removal may break other files.
            elif diff_obj.change_type == "D" and diff_obj.a_path.endswith(".py"):
                code_diff.append(diff_obj.a_path)
            # For modified or renamed Python files, check there are some real changes.
            elif diff_obj.change_type in ["M", "R"] and diff_obj.b_path.endswith(".py"):
                # For renames, keep both the old and the new name.
                if diff_obj.a_path != diff_obj.b_path:
                    code_diff.extend([diff_obj.a_path, diff_obj.b_path])
                else:
                    # Otherwise, ignore diffs that only touch docstrings or comments.
                    if diff_is_docstring_only(repo, commit, diff_obj.b_path):
                        print(f"Ignoring diff in {diff_obj.b_path} as it only concerns docstrings or comments.")
                    else:
                        code_diff.append(diff_obj.a_path)

    return code_diff


def get_module_dependencies(module_fname):
    """
    Get the dependencies of a module, i.e. the list of module files it imports.
    """
    with open(os.path.join(PATH_TO_TRANFORMERS, module_fname), "r", encoding="utf-8") as f:
        content = f.read()

    module_parts = module_fname.split(os.path.sep)
    imported_modules = []

    # Let's start with relative imports.
    relative_imports = re.findall(r"from\s+(\.+\S+)\s+import\s+([^\n]+)\n", content)
    relative_imports = [mod for mod, imp in relative_imports if "# tests_ignore" not in imp]
    for imp in relative_imports:
        # Count the leading dots to know how many levels to go up.
        level = 0
        while imp.startswith("."):
            imp = imp[1:]
            level += 1

        if len(imp) > 0:
            dep_parts = module_parts[: len(module_parts) - level] + imp.split(".")
        else:
            dep_parts = module_parts[: len(module_parts) - level] + ["__init__.py"]
        imported_module = os.path.sep.join(dep_parts)
        # Ignore imports of the main init: it would pull in the whole library as a dependency.
        if not imported_module.endswith("transformers/__init__.py"):
            imported_modules.append(imported_module)

    # Then deal with direct imports of the form `from transformers.xxx import yyy`.
    direct_imports = re.findall(r"from\s+transformers\.(\S+)\s+import\s+([^\n]+)\n", content)
    direct_imports = [mod for mod, imp in direct_imports if "# tests_ignore" not in imp]
    for imp in direct_imports:
        import_parts = imp.split(".")
        dep_parts = ["src", "transformers"] + import_parts
        imported_modules.append(os.path.sep.join(dep_parts))

    # Map each import to an actual file: either `module.py` or the `__init__.py` of a submodule.
    dependencies = []
    for imported_module in imported_modules:
        if os.path.isfile(os.path.join(PATH_TO_TRANFORMERS, f"{imported_module}.py")):
            dependencies.append(f"{imported_module}.py")
        elif os.path.isdir(os.path.join(PATH_TO_TRANFORMERS, imported_module)) and os.path.isfile(
            os.path.sep.join([PATH_TO_TRANFORMERS, imported_module, "__init__.py"])
        ):
            dependencies.append(os.path.sep.join([imported_module, "__init__.py"]))
    return dependencies
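
# Illustrative resolution (a sketch with hypothetical contents): inside
# `src/transformers/models/bert/modeling_bert.py`, the statement
# `from ...modeling_utils import PreTrainedModel` climbs three levels, so the
# dependency resolves to `src/transformers/modeling_utils.py`.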


def get_test_dependencies(test_fname):
    """
    Get the dependencies of a test file, i.e. the other test files it imports.
    """
    with open(os.path.join(PATH_TO_TRANFORMERS, test_fname), "r", encoding="utf-8") as f:
        content = f.read()

    # Tests only have relative imports for other test files.
    relative_imports = re.findall(r"from\s+(\.\S+)\s+import\s+([^\n]+)\n", content)
    relative_imports = [test for test, imp in relative_imports if "# tests_ignore" not in imp]

    def _convert_relative_import_to_file(relative_import):
        # Count the leading dots to know how many levels to go up.
        level = 0
        while relative_import.startswith("."):
            level += 1
            relative_import = relative_import[1:]

        directory = os.path.sep.join(test_fname.split(os.path.sep)[:-level])
        return os.path.join(directory, f"{relative_import.replace('.', os.path.sep)}.py")

    dependencies = [_convert_relative_import_to_file(relative_import) for relative_import in relative_imports]
    return [f for f in dependencies if os.path.isfile(os.path.join(PATH_TO_TRANFORMERS, f))]
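
# Illustrative resolution (a sketch with hypothetical contents): inside
# `tests/models/bert/test_modeling_bert.py`, the statement
# `from ...test_modeling_common import ModelTesterMixin` goes up three levels
# and resolves to `tests/test_modeling_common.py`.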


def create_reverse_dependency_tree():
    """
    Create the list of all edges (a, b) meaning that modifying `a` impacts `b`, with `a` ranging over all module and
    test files.
    """
    modules = [
        str(f.relative_to(PATH_TO_TRANFORMERS))
        for f in (Path(PATH_TO_TRANFORMERS) / "src/transformers").glob("**/*.py")
    ]
    module_edges = [(d, m) for m in modules for d in get_module_dependencies(m)]

    tests = [str(f.relative_to(PATH_TO_TRANFORMERS)) for f in (Path(PATH_TO_TRANFORMERS) / "tests").glob("**/*.py")]
    test_edges = [(d, t) for t in tests for d in get_test_dependencies(t)]

    return module_edges + test_edges


def get_tree_starting_at(module, edges):
    """
    Returns the tree starting at a given module following all edges, in the following format: [module, [list of edges
    starting at module], [list of edges starting at the vertices of the preceding level], ...]
    """
    vertices_seen = [module]
    new_edges = [edge for edge in edges if edge[0] == module and edge[1] != module]
    tree = [module]
    while len(new_edges) > 0:
        tree.append(new_edges)
        final_vertices = list({edge[1] for edge in new_edges})
        vertices_seen.extend(final_vertices)
        new_edges = [edge for edge in edges if edge[0] in final_vertices and edge[1] not in vertices_seen]

    return tree
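
# Illustrative shape of the result (hypothetical file names):
#
#     ["a.py", [("a.py", "b.py"), ("a.py", "c.py")], [("b.py", "d.py")]]
#
# i.e. the module itself, then the edges leaving it, then the edges leaving the
# vertices reached at the previous level, with already-seen vertices skipped.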


def print_tree_deps_of(module, all_edges=None):
    """
    Prints the tree of modules depending on a given module.
    """
    if all_edges is None:
        all_edges = create_reverse_dependency_tree()
    tree = get_tree_starting_at(module, all_edges)

    # `lines` is a list of tuples (line to print, module printed on that line); the second element lets us insert the
    # children of a module just below it.
    lines = [(tree[0], tree[0])]
    for index in range(1, len(tree)):
        edges = tree[index]
        start_edges = {edge[0] for edge in edges}

        for start in start_edges:
            end_edges = {edge[1] for edge in edges if edge[0] == start}
            # Locate the line printing `start` and insert its children just after it.
            pos = 0
            while lines[pos][1] != start:
                pos += 1
            lines = lines[: pos + 1] + [(" " * (2 * index) + end, end) for end in end_edges] + lines[pos + 1 :]

    for line in lines:
        # Only print the formatted line, not the bookkeeping module name.
        print(line[0])
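
# Illustrative output (hypothetical modules; indentation grows by two spaces per dependency level):
#
#     src/transformers/utils/generic.py
#       src/transformers/file_utils.py
#         src/transformers/modeling_utils.py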


def create_reverse_dependency_map():
    """
    Create the dependency map from module/test filename to the list of modules/tests that depend on it (even
    recursively).
    """
    modules = [
        str(f.relative_to(PATH_TO_TRANFORMERS))
        for f in (Path(PATH_TO_TRANFORMERS) / "src/transformers").glob("**/*.py")
    ]
    # Start from the direct dependencies of each module.
    direct_deps = {m: get_module_dependencies(m) for m in modules}

    # Add the direct dependencies of each test file.
    tests = [str(f.relative_to(PATH_TO_TRANFORMERS)) for f in (Path(PATH_TO_TRANFORMERS) / "tests").glob("**/*.py")]
    direct_deps.update({t: get_test_dependencies(t) for t in tests})

    all_files = modules + tests

    # Compute the transitive closure: keep adding the dependencies of dependencies until nothing changes.
    something_changed = True
    while something_changed:
        something_changed = False
        for m in all_files:
            for d in direct_deps[m]:
                if d not in direct_deps:
                    raise ValueError(f"KeyError:{d}. From {m}")
                for dep in direct_deps[d]:
                    if dep not in direct_deps[m]:
                        direct_deps[m].append(dep)
                        something_changed = True

    # Invert the map: a file impacts everything that (transitively) depends on it. Init files additionally report
    # their own dependencies as impacted, since a change in an init can affect every module it exposes.
    reverse_map = collections.defaultdict(list)
    for m in all_files:
        if m.endswith("__init__.py"):
            reverse_map[m].extend(direct_deps[m])
        for d in direct_deps[m]:
            reverse_map[d].append(m)

    return reverse_map
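
# Illustrative closure (hypothetical files): if `a.py` imports `b.py` and `b.py`
# imports `c.py`, the fixed-point loop above adds `c.py` to the deps of `a.py`,
# so the reverse map sends `c.py` to both `b.py` and `a.py`.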


# Module files whose test file(s) cannot be inferred automatically from their name go in this map.
SPECIAL_MODULE_TO_TEST_MAP = {
    "commands/add_new_model_like.py": "utils/test_add_new_model_like.py",
    "configuration_utils.py": "test_configuration_common.py",
    "convert_graph_to_onnx.py": "onnx/test_onnx.py",
    "data/data_collator.py": "trainer/test_data_collator.py",
    "deepspeed.py": "deepspeed/",
    "feature_extraction_sequence_utils.py": "test_sequence_feature_extraction_common.py",
    "feature_extraction_utils.py": "test_feature_extraction_common.py",
    "file_utils.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
    "image_processing_utils.py": ["test_image_processing_common.py", "utils/test_image_processing_utils.py"],
    "image_transforms.py": "test_image_transforms.py",
    "utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py", "utils/test_generic.py"],
    "utils/hub.py": "utils/test_hub_utils.py",
    "modelcard.py": "utils/test_model_card.py",
    "modeling_flax_utils.py": "test_modeling_flax_common.py",
    "modeling_tf_utils.py": ["test_modeling_tf_common.py", "utils/test_modeling_tf_core.py"],
    "modeling_utils.py": ["test_modeling_common.py", "utils/test_offline.py"],
    "models/auto/modeling_auto.py": [
        "models/auto/test_modeling_auto.py",
        "models/auto/test_modeling_tf_pytorch.py",
        "models/bort/test_modeling_bort.py",
        "models/dit/test_modeling_dit.py",
    ],
    "models/auto/modeling_flax_auto.py": "models/auto/test_modeling_flax_auto.py",
    "models/auto/modeling_tf_auto.py": [
        "models/auto/test_modeling_tf_auto.py",
        "models/auto/test_modeling_tf_pytorch.py",
        "models/bort/test_modeling_tf_bort.py",
    ],
    "models/gpt2/modeling_gpt2.py": [
        "models/gpt2/test_modeling_gpt2.py",
        "models/megatron_gpt2/test_modeling_megatron_gpt2.py",
    ],
    "models/dpt/modeling_dpt.py": [
        "models/dpt/test_modeling_dpt.py",
        "models/dpt/test_modeling_dpt_hybrid.py",
    ],
    "optimization.py": "optimization/test_optimization.py",
    "optimization_tf.py": "optimization/test_optimization_tf.py",
    "pipelines/__init__.py": all_pipeline_test_files + all_model_test_files,
    "pipelines/base.py": all_pipeline_test_files + all_model_test_files,
    "pipelines/text2text_generation.py": [
        "pipelines/test_pipelines_text2text_generation.py",
        "pipelines/test_pipelines_summarization.py",
        "pipelines/test_pipelines_translation.py",
    ],
    "pipelines/zero_shot_classification.py": "pipelines/test_pipelines_zero_shot.py",
    "testing_utils.py": "utils/test_skip_decorators.py",
    "tokenization_utils.py": ["test_tokenization_common.py", "tokenization/test_tokenization_utils.py"],
    "tokenization_utils_base.py": ["test_tokenization_common.py", "tokenization/test_tokenization_utils.py"],
    "tokenization_utils_fast.py": [
        "test_tokenization_common.py",
        "tokenization/test_tokenization_utils.py",
        "tokenization/test_tokenization_fast.py",
    ],
    "trainer.py": [
        "trainer/test_trainer.py",
        "extended/test_trainer_ext.py",
        "trainer/test_trainer_distributed.py",
        "trainer/test_trainer_tpu.py",
    ],
    "trainer_pt_utils.py": "trainer/test_trainer_utils.py",
    "utils/versions.py": "utils/test_versions_utils.py",
}


def module_to_test_file(module_fname):
    """
    Returns the name of the file(s) where `module_fname` is tested, or `None` if no test file is found.
    """
    splits = module_fname.split(os.path.sep)

    # The special map has priority.
    short_name = os.path.sep.join(splits[2:])
    if short_name in SPECIAL_MODULE_TO_TEST_MAP:
        test_file = SPECIAL_MODULE_TO_TEST_MAP[short_name]
        if isinstance(test_file, str):
            return f"tests/{test_file}"
        return [f"tests/{f}" for f in test_file]

    module_name = splits[-1]
    # Fast tokenizers are tested in the same file as the slow ones.
    if module_name.endswith("_fast.py"):
        module_name = module_name.replace("_fast.py", ".py")

    # A pipeline change impacts all model test files on top of the dedicated pipeline test.
    if len(splits) >= 2 and splits[-2] == "pipelines":
        default_test_file = f"tests/pipelines/test_pipelines_{module_name}"
        return [default_test_file] + all_model_test_files
    # Benchmark files are all tested together.
    elif len(splits) >= 2 and splits[-2] == "benchmark":
        return ["tests/benchmark/test_benchmark.py", "tests/benchmark/test_benchmark_tf.py"]
    # Commands are tested through the CLI test file.
    elif len(splits) >= 2 and splits[-2] == "commands":
        return "tests/utils/test_cli.py"
    # ONNX files are all tested together.
    elif len(splits) >= 2 and splits[-2] == "onnx":
        return ["tests/onnx/test_features.py", "tests/onnx/test_onnx.py", "tests/onnx/test_onnx_v2.py"]
    # Otherwise, infer the default test file from the module path.
    elif len(splits) > 0 and splits[0] == "utils":
        default_test_file = f"tests/repo_utils/test_{module_name}"
    elif len(splits) > 4 and splits[2] == "models":
        default_test_file = f"tests/models/{splits[3]}/test_{module_name}"
    elif len(splits) > 2 and splits[2].startswith("generation"):
        default_test_file = f"tests/generation/test_{module_name}"
    elif len(splits) > 2 and splits[2].startswith("trainer"):
        default_test_file = f"tests/trainer/test_{module_name}"
    else:
        default_test_file = f"tests/utils/test_{module_name}"

    if os.path.isfile(default_test_file):
        return default_test_file

    # Processing files may be tested in a `processor` test file.
    if "processing" in default_test_file:
        test_file = default_test_file.replace("processing", "processor")
        if os.path.isfile(test_file):
            return test_file
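
# Illustrative mappings (a sketch; actual results depend on which test files exist on disk):
#
#     module_to_test_file("src/transformers/models/bert/modeling_bert.py")
#     # -> "tests/models/bert/test_modeling_bert.py"
#     module_to_test_file("src/transformers/trainer.py")
#     # -> ["tests/trainer/test_trainer.py", "tests/extended/test_trainer_ext.py", ...]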


# Test files that are triggered separately and are not supposed to be run by the regular CI (see `sanity_check`).
EXPECTED_TEST_FILES_NEVER_TOUCHED = [
    "tests/generation/test_framework_agnostic.py",
    "tests/mixed_int8/test_mixed_int8.py",
    "tests/pipelines/test_pipelines_common.py",
    "tests/sagemaker/test_single_node_gpu.py",
    "tests/sagemaker/test_multi_node_model_parallel.py",
    "tests/sagemaker/test_multi_node_data_parallel.py",
    "tests/test_pipeline_mixin.py",
    "tests/utils/test_doc_samples.py",
]


def _print_list(l):
    return "\n".join([f"- {f}" for f in l])


def sanity_check():
    """
    Checks that all test files can be touched by a modification in at least one module/utils. This ensures that
    newly-added test files are properly mapped to some module or utils, so they can be run by the CI.
    """
    # Grab all module and utils files.
    all_files = [
        str(p.relative_to(PATH_TO_TRANFORMERS))
        for p in (Path(PATH_TO_TRANFORMERS) / "src/transformers").glob("**/*.py")
    ]
    all_files += [
        str(p.relative_to(PATH_TO_TRANFORMERS)) for p in (Path(PATH_TO_TRANFORMERS) / "utils").glob("**/*.py")
    ]

    # Compute all the test files we get from those.
    test_files_found = []
    for f in all_files:
        test_f = module_to_test_file(f)
        if test_f is not None:
            if isinstance(test_f, str):
                test_files_found.append(test_f)
            else:
                test_files_found.extend(test_f)

    # Some of the results may be directories: expand them to the test files they contain.
    test_files = []
    for test_f in test_files_found:
        if os.path.isdir(os.path.join(PATH_TO_TRANFORMERS, test_f)):
            test_files.extend(
                [
                    str(p.relative_to(PATH_TO_TRANFORMERS))
                    for p in (Path(PATH_TO_TRANFORMERS) / test_f).glob("**/test*.py")
                ]
            )
        else:
            test_files.append(test_f)

    # Compare to the test files that actually exist.
    existing_test_files = [
        str(p.relative_to(PATH_TO_TRANFORMERS)) for p in (Path(PATH_TO_TRANFORMERS) / "tests").glob("**/test*.py")
    ]
    not_touched_test_files = [f for f in existing_test_files if f not in test_files]

    should_be_tested = set(not_touched_test_files) - set(EXPECTED_TEST_FILES_NEVER_TOUCHED)
    if len(should_be_tested) > 0:
        raise ValueError(
            "The following test files are not currently associated with any module or utils files, which means they "
            f"will never get run by the CI:\n{_print_list(should_be_tested)}\n. Make sure the names of these test "
            "files match the name of the module or utils they are testing, or adapt the constant "
            "`SPECIAL_MODULE_TO_TEST_MAP` in `utils/tests_fetcher.py` to add them. If your test file is triggered "
            "separately and is not supposed to be run by the regular CI, add it to the "
            "`EXPECTED_TEST_FILES_NEVER_TOUCHED` constant instead."
        )


def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None, json_output_file=None):
    modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
    print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")

    # Create the map that will give us all impacted modules.
    impacted_modules_map = create_reverse_dependency_map()
    impacted_files = modified_files.copy()
    for f in modified_files:
        if f in impacted_modules_map:
            impacted_files.extend(impacted_modules_map[f])

    # Remove duplicates.
    impacted_files = sorted(set(impacted_files))
    print(f"\n### IMPACTED FILES ###\n{_print_list(impacted_files)}")

    # If `setup.py` was modified, any dependency may have changed: run everything.
    if "setup.py" in impacted_files:
        test_files_to_run = ["tests"]
        repo_utils_launch = True
    else:
        # Grab the corresponding test files.
        test_files_to_run = []
        for f in impacted_files:
            # Modified test files are always added.
            if f.startswith("tests/"):
                test_files_to_run.append(f)
            # Example files are tested through their dedicated example test files.
            elif f.startswith("examples/pytorch"):
                test_files_to_run.append("examples/pytorch/test_pytorch_examples.py")
                test_files_to_run.append("examples/pytorch/test_accelerate_examples.py")
            elif f.startswith("examples/tensorflow"):
                test_files_to_run.append("examples/tensorflow/test_tensorflow_examples.py")
            elif f.startswith("examples/flax"):
                test_files_to_run.append("examples/flax/test_flax_examples.py")
            else:
                new_tests = module_to_test_file(f)
                if new_tests is not None:
                    if isinstance(new_tests, str):
                        test_files_to_run.append(new_tests)
                    else:
                        test_files_to_run.extend(new_tests)

        # Remove duplicates.
        test_files_to_run = sorted(set(test_files_to_run))
        # Make sure we did not end up with a test file that was removed.
        test_files_to_run = [f for f in test_files_to_run if os.path.isfile(f) or os.path.isdir(f)]
        if filters is not None:
            filtered_files = []
            for _filter in filters:
                filtered_files.extend([f for f in test_files_to_run if f.startswith(_filter)])
            test_files_to_run = filtered_files
        repo_utils_launch = any(f.split(os.path.sep)[1] == "repo_utils" for f in test_files_to_run)

    if repo_utils_launch:
        repo_util_file = Path(output_file).parent / "test_repo_utils.txt"
        with open(repo_util_file, "w", encoding="utf-8") as f:
            f.write("tests/repo_utils")

    print(f"\n### TESTS TO RUN ###\n{_print_list(test_files_to_run)}")
    if len(test_files_to_run) > 0:
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(" ".join(test_files_to_run))

    # Expand `tests` (i.e. run everything) into the explicit list of test folders/files for the JSON map.
    if "tests" in test_files_to_run:
        test_files_to_run = get_all_tests()

    if json_output_file is not None:
        test_map = {}
        for test_file in test_files_to_run:
            # Get the category for this test file, e.g. `models/bert` for `tests/models/bert/test_modeling_bert.py`.
            names = test_file.split(os.path.sep)
            if names[1] == "models":
                # Model tests are grouped per model folder.
                key = "/".join(names[1:3])
            elif len(names) > 2 or not test_file.endswith(".py"):
                # Other test files in a subfolder (or whole folders) are grouped per top-level folder.
                key = "/".join(names[1:2])
            else:
                # Common test files directly under `tests` go into the `common` category.
                key = "common"

            if key not in test_map:
                test_map[key] = []
            test_map[key].append(test_file)

        # Sort both the keys and the test files inside each key.
        keys = sorted(test_map.keys())
        test_map = {k: " ".join(sorted(test_map[k])) for k in keys}
        with open(json_output_file, "w", encoding="UTF-8") as fp:
            json.dump(test_map, fp, ensure_ascii=False)
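
# Illustrative `test_map` content written to the JSON file (hypothetical):
#
#     {"models/bert": "tests/models/bert/test_modeling_bert.py",
#      "pipelines": "tests/pipelines/test_pipelines_fill_mask.py",
#      "common": "tests/test_configuration_common.py"}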


def filter_tests(output_file, filters):
    """
    Reads the content of the output file and filters out all the tests in a list of given folders.

    Args:
        output_file (`str` or `os.PathLike`): The path to the output file of the tests fetcher.
        filters (`List[str]`): A list of folders to filter.
    """
    if not os.path.isfile(output_file):
        print("No test file found.")
        return
    with open(output_file, "r", encoding="utf-8") as f:
        test_files = f.read().split(" ")

    if len(test_files) == 0 or test_files == [""]:
        print("No tests to filter.")
        return

    if test_files == ["tests"]:
        test_files = [os.path.join("tests", f) for f in os.listdir("tests") if f not in ["__init__.py"] + filters]
    else:
        test_files = [f for f in test_files if f.split(os.path.sep)[1] not in filters]

    with open(output_file, "w", encoding="utf-8") as f:
        f.write(" ".join(test_files))
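
# Illustrative effect (a sketch with hypothetical file contents): if `test_list.txt` contains
# `tests/models/bert/test_modeling_bert.py tests/repo_utils/test_tests_fetcher.py`, then
# `filter_tests("test_list.txt", ["repo_utils"])` rewrites it to keep only
# `tests/models/bert/test_modeling_bert.py`.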


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sanity_check", action="store_true", help="Only test that all tests and modules are accounted for."
    )
    parser.add_argument(
        "--output_file", type=str, default="test_list.txt", help="Where to store the list of tests to run"
    )
    parser.add_argument(
        "--json_output_file",
        type=str,
        default="test_map.json",
        help="Where to store the tests to run in a dictionary format mapping test categories to test files",
    )
    parser.add_argument(
        "--diff_with_last_commit",
        action="store_true",
        help="To fetch the tests between the current commit and the last commit",
    )
    parser.add_argument(
        "--filters",
        type=str,
        nargs="*",
        default=["tests"],
        help="Only keep the test files matching one of those filters.",
    )
    parser.add_argument(
        "--filter_tests",
        action="store_true",
        help="Will filter the pipeline/repo utils tests outside of the generated list of tests.",
    )
    parser.add_argument(
        "--print_dependencies_of",
        type=str,
        help="Will only print the tree of modules depending on the file passed.",
        default=None,
    )
    args = parser.parse_args()
    if args.print_dependencies_of is not None:
        print_tree_deps_of(args.print_dependencies_of)
    elif args.sanity_check:
        sanity_check()
    elif args.filter_tests:
        filter_tests(args.output_file, ["pipelines", "repo_utils"])
    else:
        repo = Repo(PATH_TO_TRANFORMERS)

        diff_with_last_commit = args.diff_with_last_commit
        if not diff_with_last_commit and not repo.head.is_detached and repo.head.ref == repo.refs.main:
            print("main branch detected, fetching tests against last commit.")
            diff_with_last_commit = True

        try:
            infer_tests_to_run(
                args.output_file,
                diff_with_last_commit=diff_with_last_commit,
                filters=args.filters,
                json_output_file=args.json_output_file,
            )
            filter_tests(args.output_file, ["repo_utils"])
        except Exception as e:
            print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
            with open(args.output_file, "w", encoding="utf-8") as f:
                if args.filters is None:
                    f.write("./tests/")
                else:
                    f.write(" ".join(args.filters))