Update gptx_tokenizer.py
Browse files- gptx_tokenizer.py +53 -23
gptx_tokenizer.py
CHANGED
@@ -62,7 +62,7 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
62 |
f"<placeholder_tok_{i}>" for i in range(256)
|
63 |
]
|
64 |
|
65 |
-
|
66 |
if not os.path.isfile(config_path):
|
67 |
config_path = try_to_load_from_cache(repo_id=repo_id, filename=Path(config_path).name)
|
68 |
if not config_path:
|
@@ -74,43 +74,73 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
|
|
74 |
def instantiate_from_file_or_name(self, model_file_or_name: str, repo_id: str = None):
    """
    Load the tokenizer model from a local file or download it from a repository.

    Args:
        model_file_or_name (str): Path to the model file or the model name.
        repo_id (str, optional): Repository ID from which to download the model file.

    Returns:
        spm.SentencePieceProcessor: Loaded SentencePieceProcessor instance.

    Raises:
        OSError: If the model file cannot be downloaded or loaded.
    """
    if not os.path.isfile(model_file_or_name):
        try:
            # List all files in the repo and keep only SentencePiece model files.
            repo_files = list_repo_files(repo_id)
            tokenizer_files = [f for f in repo_files if f.endswith('.model')]
            if not tokenizer_files:
                raise FileNotFoundError(f"No .model file found in repository {repo_id}")

            # Use the first .model file found.
            model_file = tokenizer_files[0]
            print(f"Found tokenizer model file: {model_file}")

            # Download the file into the local cache and use its local path.
            model_file_or_name = hf_hub_download(repo_id=repo_id, filename=model_file)
            print(f"Downloaded tokenizer model to: {model_file_or_name}")
        except Exception as e:
            # Chain the original exception so the root cause stays in the traceback.
            raise OSError(f"Failed to download tokenizer model: {str(e)}") from e

    try:
        return spm.SentencePieceProcessor(model_file=model_file_or_name)
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise OSError(f"Failed to load tokenizer model: {str(e)}") from e
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
def __init__(
|
115 |
self,
|
116 |
model_path: Optional[str] = None,
|
|
|
62 |
f"<placeholder_tok_{i}>" for i in range(256)
|
63 |
]
|
64 |
|
65 |
+
def find_tokenizer_config(self, config_path: Path, repo_id: str = None) -> Optional[Path]:
|
66 |
if not os.path.isfile(config_path):
|
67 |
config_path = try_to_load_from_cache(repo_id=repo_id, filename=Path(config_path).name)
|
68 |
if not config_path:
|
|
|
74 |
def instantiate_from_file_or_name(self, model_file_or_name: str, repo_id: str = None):
    """
    Load the tokenizer model from a file or download it from a repository.

    Args:
        model_file_or_name (str): Path to the model file or the model name.
        repo_id (str, optional): Repository ID from which to download the model file.

    Returns:
        spm.SentencePieceProcessor: Loaded SentencePieceProcessor instance.

    Raises:
        OSError: If the model file cannot be loaded or downloaded.
    """
    if not os.path.isfile(model_file_or_name):
        # Check the local Hugging Face cache before going to the network.
        model_file_or_name = try_to_load_from_cache(repo_id=repo_id, filename=Path(model_file_or_name).name)
        if not model_file_or_name:
            # Not cached: resolve and download the .model file from the hub.
            model_file_or_name = self._download_model_from_hub(repo_id=repo_id)

    try:
        return spm.SentencePieceProcessor(model_file=model_file_or_name)
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise OSError(f"Failed to load tokenizer model: {str(e)}") from e
|
98 |
+
|
99 |
+
def _download_model_from_hub(self, repo_id: str) -> Optional[str]:
|
100 |
+
try:
|
101 |
+
# List all files in the repo
|
102 |
+
repo_files = list_repo_files(repo_id)
|
103 |
+
|
104 |
+
# Find the tokenizer model file
|
105 |
+
tokenizer_files = [f for f in repo_files if f.endswith('.model')]
|
106 |
+
if not tokenizer_files:
|
107 |
+
raise FileNotFoundError(f"No .model file found in repository {repo_id}")
|
108 |
+
|
109 |
+
# Use the first .model file found
|
110 |
+
model_file = tokenizer_files[0]
|
111 |
+
print(f"Found tokenizer model file: {model_file}")
|
112 |
+
|
113 |
+
# Download the file
|
114 |
+
model_file_or_name = hf_hub_download(repo_id=repo_id, filename=model_file)
|
115 |
+
print(f"Downloaded tokenizer model to: {model_file_or_name}")
|
116 |
+
except Exception as e:
|
117 |
+
raise OSError(f"Failed to download tokenizer model: {str(e)}")
|
118 |
+
|
119 |
+
return model_file_or_name
|
120 |
+
|
121 |
+
def _download_config_from_hub(self, repo_id: str):
|
122 |
+
if repo_id is None:
|
123 |
+
raise ValueError("repo_id must be provided if config_path is not a local file")
|
124 |
+
|
125 |
+
try:
|
126 |
+
# List all files in the repo
|
127 |
+
repo_files = list_repo_files(repo_id)
|
128 |
+
|
129 |
+
# Find the tokenizer config file
|
130 |
+
tokenizer_files = [f for f in repo_files if f.endswith('tokenizer_config.json')]
|
131 |
+
if not tokenizer_files:
|
132 |
+
raise FileNotFoundError(f"No tokenizer_config.json file found in repository {repo_id}")
|
133 |
+
|
134 |
+
# Use the first tokenizer_config.json file found
|
135 |
+
tokenizer_config_file = tokenizer_files[0]
|
136 |
+
print(f"Found tokenizer config file: {tokenizer_config_file}")
|
137 |
+
|
138 |
+
# Download the file
|
139 |
+
tokenizer_config_file_or_name = hf_hub_download(repo_id=repo_id, filename=tokenizer_config_file)
|
140 |
+
print(f"Downloaded tokenizer config file to: {tokenizer_config_file_or_name}")
|
141 |
+
return tokenizer_config_file_or_name
|
142 |
+
except Exception as e:
|
143 |
+
raise OSError(f"Failed to download tokenizer model: {str(e)}")
|
144 |
def __init__(
|
145 |
self,
|
146 |
model_path: Optional[str] = None,
|