Spaces:
Runtime error
Runtime error
File size: 1,555 Bytes
abca9bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import json
import logging
import sys
from typing import Tuple
from transformers import HfArgumentParser
from transformers.hf_argparser import DataClass
class CustomHfArgumentParser(HfArgumentParser):
    """`HfArgumentParser` that can take default values from leading JSON config files on the command line."""

    def parse_dictionary_and_args(self) -> Tuple[DataClass, ...]:
        """
        Parse ``sys.argv``, where any number of leading ``*.json`` paths supply argument values.

        Every argument immediately after the script name that ends in ``.json`` is loaded as a
        JSON object. Each ``key: value`` pair becomes ``--key value`` (list values expand to
        ``--key item1 item2 ...``). These generated arguments are placed *before* the remaining
        command-line arguments, so options given explicitly on the command line take precedence.
        Everything from the first non-``.json`` argument onward is forwarded unchanged. The
        combined list is parsed by ``parse_args_into_dataclasses``, which still uses ``argparse``.

        Returns:
            The populated dataclass instances, as returned by ``parse_args_into_dataclasses``.

        Raises:
            ValueError: if a config file does not contain a JSON object, or if two config files
                define the same key.
        """
        cli_args = sys.argv[1:]

        # Count the leading JSON config paths explicitly instead of reusing a loop index.
        # Reusing the index broke with no arguments at all (index never bound) and with only
        # config files (the last config path leaked into the argparse input).
        num_configs = 0
        for arg in cli_args:
            if not arg.endswith('.json'):
                break
            num_configs += 1
        config_paths = cli_args[:num_configs]
        remaining_args = cli_args[num_configs:]

        data = {}
        for path in config_paths:
            with open(path, encoding='utf-8') as f:
                new_data = json.load(f)
            if not isinstance(new_data, dict):
                raise ValueError(
                    f'Config file {path} must contain a JSON object, got {type(new_data).__name__}'
                )
            conflicting_keys = set(new_data).intersection(data)
            if conflicting_keys:
                raise ValueError(f'There are conflicting keys in the config files: {conflicting_keys}')
            data.update(new_data)

        args = []
        for k, v in data.items():
            flag = f'--{k}'
            # If any options were given explicitly through the CLA (as `--k v` or `--k=v`),
            # they override anything defined in the config files.
            if any(arg == flag or arg.startswith(flag + '=') for arg in remaining_args):
                logging.info(
                    'While %s=%s was given in a config file, a manual override was set through the CLA', k, v
                )
                continue
            # argparse only accepts string tokens, so list elements are stringified as well.
            values = v if isinstance(v, list) else [v]
            args.extend([flag, *(str(item) for item in values)])

        # Add the config-file arguments first so that command-line arguments take precedence.
        args += remaining_args
        return self.parse_args_into_dataclasses(args=args, look_for_args_file=False)