Alina Lozovskaia commited on
Commit
d95d4a1
·
1 Parent(s): 9b7814c

apply code style and quality checks to read_evals.py

Browse files
Files changed (1) hide show
  1. src/leaderboard/read_evals.py +27 -29
src/leaderboard/read_evals.py CHANGED
@@ -16,36 +16,36 @@ from src.display.formatting import make_clickable_model
16
  from src.display.utils import AutoEvalColumn, ModelType, Precision, Tasks, WeightType, parse_datetime
17
 
18
  # Configure logging
19
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
20
 
21
  @dataclass
22
  class EvalResult:
23
  # Also see src.display.utils.AutoEvalColumn for what will be displayed.
24
- eval_name: str # org_model_precision (uid)
25
- full_model: str # org/model (path on hub)
26
  org: Optional[str]
27
  model: str
28
- revision: str # commit hash, "" if main
29
  results: Dict[str, float]
30
  precision: Precision = Precision.Unknown
31
- model_type: ModelType = ModelType.Unknown # Pretrained, fine tuned, ...
32
  weight_type: WeightType = WeightType.Original
33
- architecture: str = "Unknown" # From config file
34
  license: str = "?"
35
  likes: int = 0
36
  num_params: int = 0
37
- date: str = "" # submission date of request file
38
  still_on_hub: bool = True
39
  is_merge: bool = False
40
  not_flagged: bool = False
41
  status: str = "FINISHED"
42
  # List of tags, initialized to a new empty list for each instance to avoid the pitfalls of mutable default arguments.
43
  tags: List[str] = field(default_factory=list)
44
-
45
-
46
  @classmethod
47
- def init_from_json_file(cls, json_filepath: str) -> 'EvalResult':
48
- with open(json_filepath, 'r') as fp:
49
  data = json.load(fp)
50
 
51
  config = data.get("config_general", {})
@@ -72,7 +72,7 @@ class EvalResult:
72
  model=model,
73
  results=results,
74
  precision=precision,
75
- revision=config.get("model_sha", "")
76
  )
77
 
78
  @staticmethod
@@ -118,9 +118,8 @@ class EvalResult:
118
 
119
  mean_acc = np.mean(accs) * 100.0
120
  results[task.benchmark] = mean_acc
121
-
122
- return results
123
 
 
124
 
125
  def update_with_request_file(self, requests_path):
126
  """Finds the relevant request file for the current model and updates info with it."""
@@ -130,17 +129,17 @@ class EvalResult:
130
  logging.warning(f"No request file for {self.org}/{self.model}")
131
  self.status = "FAILED"
132
  return
133
-
134
  with open(request_file, "r") as f:
135
  request = json.load(f)
136
-
137
  self.model_type = ModelType.from_str(request.get("model_type", "Unknown"))
138
  self.weight_type = WeightType[request.get("weight_type", "Original")]
139
  self.num_params = int(request.get("params", 0)) # Ensuring type safety
140
  self.date = request.get("submitted_time", "")
141
  self.architecture = request.get("architectures", "Unknown")
142
  self.status = request.get("status", "FAILED")
143
-
144
  except FileNotFoundError:
145
  self.status = "FAILED"
146
  logging.error(f"Request file: {request_file} not found for {self.org}/{self.model}")
@@ -154,7 +153,6 @@ class EvalResult:
154
  self.status = "FAILED"
155
  logging.error(f"Unexpected error {e} for {self.org}/{self.model}")
156
 
157
-
158
  def update_with_dynamic_file_dict(self, file_dict):
159
  """Update object attributes based on the provided dictionary, with error handling for missing keys and type validation."""
160
  # Default values set for optional or potentially missing keys.
@@ -162,11 +160,10 @@ class EvalResult:
162
  self.likes = int(file_dict.get("likes", 0)) # Ensure likes is treated as an integer
163
  self.still_on_hub = file_dict.get("still_on_hub", False) # Default to False if key is missing
164
  self.tags = file_dict.get("tags", [])
165
-
166
  # Calculate `flagged` only if 'tags' is not empty and avoid calculating each time
167
  self.not_flagged = not (any("flagged" in tag for tag in self.tags))
168
 
169
-
170
  def to_dict(self):
171
  """Converts the Eval Result to a dict compatible with our dataframe display"""
172
  average = sum([v for v in self.results.values() if v is not None]) / len(Tasks)
@@ -185,8 +182,10 @@ class EvalResult:
185
  AutoEvalColumn.likes.name: self.likes,
186
  AutoEvalColumn.params.name: self.num_params,
187
  AutoEvalColumn.still_on_hub.name: self.still_on_hub,
188
- AutoEvalColumn.merged.name: not( "merge" in self.tags if self.tags else False),
189
- AutoEvalColumn.moe.name: not ( ("moe" in self.tags if self.tags else False) or "moe" in self.full_model.lower()) ,
 
 
190
  AutoEvalColumn.not_flagged.name: self.not_flagged,
191
  }
192
 
@@ -194,16 +193,16 @@ class EvalResult:
194
  data_dict[task.value.col_name] = self.results[task.value.benchmark]
195
 
196
  return data_dict
197
-
198
 
199
  def get_request_file_for_model(requests_path, model_name, precision):
200
  """Selects the correct request file for a given model. Only keeps runs tagged as FINISHED"""
201
  requests_path = Path(requests_path)
202
  pattern = f"{model_name}_eval_request_*.json"
203
-
204
  # Using pathlib to find files matching the pattern
205
  request_files = list(requests_path.glob(pattern))
206
-
207
  # Sort the files by name in descending order to mimic 'reverse=True'
208
  request_files.sort(reverse=True)
209
 
@@ -214,7 +213,7 @@ def get_request_file_for_model(requests_path, model_name, precision):
214
  req_content = json.load(f)
215
  if req_content["status"] == "FINISHED" and req_content["precision"] == precision.split(".")[-1]:
216
  request_file = str(request_file)
217
-
218
  # Return empty string if no file found that matches criteria
219
  return request_file
220
 
@@ -223,9 +222,9 @@ def get_raw_eval_results(results_path: str, requests_path: str, dynamic_path: st
223
  """From the path of the results folder root, extract all needed info for results"""
224
  with open(dynamic_path) as f:
225
  dynamic_data = json.load(f)
226
-
227
  results_path = Path(results_path)
228
- model_files = list(results_path.rglob('results_*.json'))
229
  model_files.sort(key=lambda file: parse_datetime(file.stem.removeprefix("results_")))
230
 
231
  eval_results = {}
@@ -260,4 +259,3 @@ def get_raw_eval_results(results_path: str, requests_path: str, dynamic_path: st
260
  continue
261
 
262
  return results
263
-
 
16
  from src.display.utils import AutoEvalColumn, ModelType, Precision, Tasks, WeightType, parse_datetime
17
 
18
  # Configure logging
19
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
20
+
21
 
22
  @dataclass
23
  class EvalResult:
24
  # Also see src.display.utils.AutoEvalColumn for what will be displayed.
25
+ eval_name: str # org_model_precision (uid)
26
+ full_model: str # org/model (path on hub)
27
  org: Optional[str]
28
  model: str
29
+ revision: str # commit hash, "" if main
30
  results: Dict[str, float]
31
  precision: Precision = Precision.Unknown
32
+ model_type: ModelType = ModelType.Unknown # Pretrained, fine tuned, ...
33
  weight_type: WeightType = WeightType.Original
34
+ architecture: str = "Unknown" # From config file
35
  license: str = "?"
36
  likes: int = 0
37
  num_params: int = 0
38
+ date: str = "" # submission date of request file
39
  still_on_hub: bool = True
40
  is_merge: bool = False
41
  not_flagged: bool = False
42
  status: str = "FINISHED"
43
  # List of tags, initialized to a new empty list for each instance to avoid the pitfalls of mutable default arguments.
44
  tags: List[str] = field(default_factory=list)
45
+
 
46
  @classmethod
47
+ def init_from_json_file(cls, json_filepath: str) -> "EvalResult":
48
+ with open(json_filepath, "r") as fp:
49
  data = json.load(fp)
50
 
51
  config = data.get("config_general", {})
 
72
  model=model,
73
  results=results,
74
  precision=precision,
75
+ revision=config.get("model_sha", ""),
76
  )
77
 
78
  @staticmethod
 
118
 
119
  mean_acc = np.mean(accs) * 100.0
120
  results[task.benchmark] = mean_acc
 
 
121
 
122
+ return results
123
 
124
  def update_with_request_file(self, requests_path):
125
  """Finds the relevant request file for the current model and updates info with it."""
 
129
  logging.warning(f"No request file for {self.org}/{self.model}")
130
  self.status = "FAILED"
131
  return
132
+
133
  with open(request_file, "r") as f:
134
  request = json.load(f)
135
+
136
  self.model_type = ModelType.from_str(request.get("model_type", "Unknown"))
137
  self.weight_type = WeightType[request.get("weight_type", "Original")]
138
  self.num_params = int(request.get("params", 0)) # Ensuring type safety
139
  self.date = request.get("submitted_time", "")
140
  self.architecture = request.get("architectures", "Unknown")
141
  self.status = request.get("status", "FAILED")
142
+
143
  except FileNotFoundError:
144
  self.status = "FAILED"
145
  logging.error(f"Request file: {request_file} not found for {self.org}/{self.model}")
 
153
  self.status = "FAILED"
154
  logging.error(f"Unexpected error {e} for {self.org}/{self.model}")
155
 
 
156
  def update_with_dynamic_file_dict(self, file_dict):
157
  """Update object attributes based on the provided dictionary, with error handling for missing keys and type validation."""
158
  # Default values set for optional or potentially missing keys.
 
160
  self.likes = int(file_dict.get("likes", 0)) # Ensure likes is treated as an integer
161
  self.still_on_hub = file_dict.get("still_on_hub", False) # Default to False if key is missing
162
  self.tags = file_dict.get("tags", [])
163
+
164
  # Calculate `flagged` only if 'tags' is not empty and avoid calculating each time
165
  self.not_flagged = not (any("flagged" in tag for tag in self.tags))
166
 
 
167
  def to_dict(self):
168
  """Converts the Eval Result to a dict compatible with our dataframe display"""
169
  average = sum([v for v in self.results.values() if v is not None]) / len(Tasks)
 
182
  AutoEvalColumn.likes.name: self.likes,
183
  AutoEvalColumn.params.name: self.num_params,
184
  AutoEvalColumn.still_on_hub.name: self.still_on_hub,
185
+ AutoEvalColumn.merged.name: not ("merge" in self.tags if self.tags else False),
186
+ AutoEvalColumn.moe.name: not (
187
+ ("moe" in self.tags if self.tags else False) or "moe" in self.full_model.lower()
188
+ ),
189
  AutoEvalColumn.not_flagged.name: self.not_flagged,
190
  }
191
 
 
193
  data_dict[task.value.col_name] = self.results[task.value.benchmark]
194
 
195
  return data_dict
196
+
197
 
198
  def get_request_file_for_model(requests_path, model_name, precision):
199
  """Selects the correct request file for a given model. Only keeps runs tagged as FINISHED"""
200
  requests_path = Path(requests_path)
201
  pattern = f"{model_name}_eval_request_*.json"
202
+
203
  # Using pathlib to find files matching the pattern
204
  request_files = list(requests_path.glob(pattern))
205
+
206
  # Sort the files by name in descending order to mimic 'reverse=True'
207
  request_files.sort(reverse=True)
208
 
 
213
  req_content = json.load(f)
214
  if req_content["status"] == "FINISHED" and req_content["precision"] == precision.split(".")[-1]:
215
  request_file = str(request_file)
216
+
217
  # Return empty string if no file found that matches criteria
218
  return request_file
219
 
 
222
  """From the path of the results folder root, extract all needed info for results"""
223
  with open(dynamic_path) as f:
224
  dynamic_data = json.load(f)
225
+
226
  results_path = Path(results_path)
227
+ model_files = list(results_path.rglob("results_*.json"))
228
  model_files.sort(key=lambda file: parse_datetime(file.stem.removeprefix("results_")))
229
 
230
  eval_results = {}
 
259
  continue
260
 
261
  return results