Text Generation
Transformers
Safetensors
llama
text-generation-inference
Inference Endpoints
mfromm committed on
Commit
e3c7073
·
verified ·
1 Parent(s): 6c94e76

Update gptx_tokenizer.py

Browse files
Files changed (1) hide show
  1. gptx_tokenizer.py +53 -23
gptx_tokenizer.py CHANGED
@@ -62,7 +62,7 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
62
  f"<placeholder_tok_{i}>" for i in range(256)
63
  ]
64
 
65
- def find_tokenizer_config(self, config_path: Path, repo_id: str = None) -> Optional[Path]:
66
  if not os.path.isfile(config_path):
67
  config_path = try_to_load_from_cache(repo_id=repo_id, filename=Path(config_path).name)
68
  if not config_path:
@@ -74,43 +74,73 @@ class HFGPTXTokenizer(PreTrainedTokenizer):
74
  def instantiate_from_file_or_name(self, model_file_or_name: str, repo_id: str = None):
75
  """
76
  Load the tokenizer model from a file or download it from a repository.
 
77
  Args:
78
  model_file_or_name (str): Path to the model file or the model name.
79
  repo_id (str, optional): Repository ID from which to download the model file.
 
80
  Returns:
81
  spm.SentencePieceProcessor: Loaded SentencePieceProcessor instance.
 
82
  Raises:
83
  ValueError: If repo_id is not provided when model_file_or_name is not a file.
84
  OSError: If the model file cannot be loaded or downloaded.
85
  """
86
  if not os.path.isfile(model_file_or_name):
87
- if repo_id is None:
88
- raise ValueError("repo_id must be provided if model_file_or_name is not a local file")
89
-
90
- try:
91
- # List all files in the repo
92
- repo_files = list_repo_files(repo_id)
93
-
94
- # Find the tokenizer model file
95
- tokenizer_files = [f for f in repo_files if f.endswith('.model')]
96
- if not tokenizer_files:
97
- raise FileNotFoundError(f"No .model file found in repository {repo_id}")
98
-
99
- # Use the first .model file found
100
- model_file = tokenizer_files[0]
101
- print(f"Found tokenizer model file: {model_file}")
102
-
103
- # Download the file
104
- model_file_or_name = hf_hub_download(repo_id=repo_id, filename=model_file)
105
- print(f"Downloaded tokenizer model to: {model_file_or_name}")
106
- except Exception as e:
107
- raise OSError(f"Failed to download tokenizer model: {str(e)}")
108
 
109
  try:
110
  return spm.SentencePieceProcessor(model_file=model_file_or_name)
111
  except Exception as e:
112
  raise OSError(f"Failed to load tokenizer model: {str(e)}")
113
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  def __init__(
115
  self,
116
  model_path: Optional[str] = None,
 
62
  f"<placeholder_tok_{i}>" for i in range(256)
63
  ]
64
 
65
+ def find_tokenizer_config(self, config_path: Path, repo_id: str = None) -> Optional[Path]:
66
  if not os.path.isfile(config_path):
67
  config_path = try_to_load_from_cache(repo_id=repo_id, filename=Path(config_path).name)
68
  if not config_path:
 
74
  def instantiate_from_file_or_name(self, model_file_or_name: str, repo_id: str = None):
75
  """
76
  Load the tokenizer model from a file or download it from a repository.
77
+
78
  Args:
79
  model_file_or_name (str): Path to the model file or the model name.
80
  repo_id (str, optional): Repository ID from which to download the model file.
81
+
82
  Returns:
83
  spm.SentencePieceProcessor: Loaded SentencePieceProcessor instance.
84
+
85
  Raises:
86
  ValueError: If repo_id is not provided when model_file_or_name is not a file.
87
  OSError: If the model file cannot be loaded or downloaded.
88
  """
89
  if not os.path.isfile(model_file_or_name):
90
+ model_file_or_name = try_to_load_from_cache(repo_id=repo_id, filename=Path(model_file_or_name).name)
91
+ if not model_file_or_name:
92
+ model_file_or_name = self._download_model_from_hub(repo_id=repo_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  try:
95
  return spm.SentencePieceProcessor(model_file=model_file_or_name)
96
  except Exception as e:
97
  raise OSError(f"Failed to load tokenizer model: {str(e)}")
98
+
99
def _download_model_from_hub(self, repo_id: str) -> Optional[str]:
    """
    Download the first ``.model`` file found in a Hub repository.

    Args:
        repo_id (str): Repository ID to list and download from.

    Returns:
        Optional[str]: Path of the downloaded tokenizer model file, as
        returned by ``hf_hub_download``.

    Raises:
        ValueError: If repo_id is None.
        OSError: If listing the repository fails, no ``.model`` file exists,
            or the download fails.
    """
    # Fail early with a clear message instead of a wrapped, opaque Hub error.
    if repo_id is None:
        raise ValueError("repo_id must be provided if model_file_or_name is not a local file")

    try:
        # List all files in the repo
        repo_files = list_repo_files(repo_id)

        # Find the tokenizer model file
        tokenizer_files = [f for f in repo_files if f.endswith('.model')]
        if not tokenizer_files:
            raise FileNotFoundError(f"No .model file found in repository {repo_id}")

        # Use the first .model file found
        model_file = tokenizer_files[0]
        print(f"Found tokenizer model file: {model_file}")

        # Download the file
        model_file_or_name = hf_hub_download(repo_id=repo_id, filename=model_file)
        print(f"Downloaded tokenizer model to: {model_file_or_name}")
    except Exception as e:
        # Every failure is reported as OSError; the cause stays chained.
        raise OSError(f"Failed to download tokenizer model: {str(e)}") from e

    return model_file_or_name
120
+
121
def _download_config_from_hub(self, repo_id: str):
    """
    Download the first ``tokenizer_config.json`` file found in a Hub repository.

    Args:
        repo_id (str): Repository ID to list and download from.

    Returns:
        Path of the downloaded tokenizer config file, as returned by
        ``hf_hub_download``.

    Raises:
        ValueError: If repo_id is None.
        OSError: If listing the repository fails, no matching file exists,
            or the download fails.
    """
    if repo_id is None:
        raise ValueError("repo_id must be provided if config_path is not a local file")

    try:
        # List all files in the repo
        repo_files = list_repo_files(repo_id)

        # Find the tokenizer config file
        tokenizer_files = [f for f in repo_files if f.endswith('tokenizer_config.json')]
        if not tokenizer_files:
            raise FileNotFoundError(f"No tokenizer_config.json file found in repository {repo_id}")

        # Use the first tokenizer_config.json file found
        tokenizer_config_file = tokenizer_files[0]
        print(f"Found tokenizer config file: {tokenizer_config_file}")

        # Download the file
        tokenizer_config_file_or_name = hf_hub_download(repo_id=repo_id, filename=tokenizer_config_file)
        print(f"Downloaded tokenizer config file to: {tokenizer_config_file_or_name}")
        return tokenizer_config_file_or_name
    except Exception as e:
        # The message names the config, not the model, so the failure is attributable.
        raise OSError(f"Failed to download tokenizer config: {str(e)}") from e
144
  def __init__(
145
  self,
146
  model_path: Optional[str] = None,