forked from KarelDO/xmc.dspy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_irera.py
94 lines (75 loc) · 2.8 KB
/
run_irera.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
os.environ["DSP_NOTEBOOK_CACHEDIR"] = os.path.join(".", "local_cache")
from dspy import Models
from src.data_loaders import load_data
from src.evaluators import create_evaluators
from src.programs import InferRetrieveRank
import argparse
def run_irera(state_path, dataset_name, do_validation, do_test):
    """Load a saved Infer-Retrieve-Rank program and optionally evaluate it.

    Args:
        state_path: Passed to ``InferRetrieveRank.load`` to restore the program.
        dataset_name: Passed to ``load_data`` to obtain the example splits.
        do_validation: When truthy, run the "rp50", "rp10" and "rp5"
            evaluators on the validation examples and print their results.
        do_test: When truthy, run the "rp10" and "rp5" evaluators on the
            test examples and print their results.

    Returns:
        The program restored from ``state_path``.
    """
    # load_data hands back six values; only the second and third (the
    # validation and test examples) are needed here.
    _, validation_examples, test_examples, _, _, _ = load_data(dataset_name)

    program = InferRetrieveRank.load(state_path)

    # (label, score) pairs, reported only after all evaluation has finished.
    report = []

    if do_validation:
        print("validating final program...")
        evaluate = create_evaluators(validation_examples)
        report.append(("Final program validation_rp50: ", evaluate["rp50"](program)))
        report.append(("Final program validation_rp10: ", evaluate["rp10"](program)))
        report.append(("Final program validation_rp5: ", evaluate["rp5"](program)))

    if do_test:
        print("testing final program...")
        evaluate = create_evaluators(test_examples)
        report.append(("Final program test_rp10: ", evaluate["rp10"](program)))
        report.append(("Final program test_rp5: ", evaluate["rp5"](program)))

    for label, score in report:
        print(label, score)

    return program
if __name__ == "__main__":
    # Script entry point: read the CLI options, echo them, configure the
    # language models, then run the saved program through run_irera.
    parser = argparse.ArgumentParser(
        description="Run Infer-Retrieve-Rank on an extreme multi-label classification (XMC) dataset."
    )

    # Paths and dataset selection.
    parser.add_argument("--state_path", type=str)
    parser.add_argument("--lm_config_path", type=str)
    parser.add_argument("--dataset_name", type=str, help="Specify the dataset")

    # Opt-in switches; both are False unless given on the command line.
    parser.add_argument(
        "--do_validation",
        help="Specify if validation results need to be calculated (default: False)",
        action="store_true",
    )
    parser.add_argument(
        "--do_test",
        help="Specify if test results need to be calculated (default: False)",
        action="store_true",
    )

    args = parser.parse_args()

    state_path, lm_config_path = args.state_path, args.lm_config_path
    dataset_name = args.dataset_name
    do_validation, do_test = args.do_validation, args.do_test

    # Echo the effective settings before doing any work.
    for label, value in (
        ("state_path: ", state_path),
        ("lm_config_path: ", lm_config_path),
        ("dataset_name: ", dataset_name),
        ("do_validation: ", do_validation),
        ("do_test: ", do_test),
    ):
        print(label, value)

    # Models is called before the program is loaded and run.
    Models(config_path=lm_config_path)

    program = run_irera(state_path, dataset_name, do_validation, do_test)