import os import json import torch import torch.nn as nn from .model import AudioClassifier from ..utils.config import dict2cfg, cfg2dict from huggingface_hub import HfApi, create_repo, hf_hub_download class HFAudioClassifier(AudioClassifier): """Hugging Face compatible AudioClassifier model""" def __init__(self, config): if isinstance(config, dict): self.config = dict2cfg(config) super().__init__(self.config) @classmethod def from_pretrained(cls, model_id, cache_dir=None, map_location="cpu", strict=False): # Check if model_id is a local path is_local = os.path.exists(model_id) if is_local: # Load from local checkpoint config_file = os.path.join(model_id, "config.json") model_file = os.path.join(model_id, "pytorch_model.bin") else: # Download from HF Hub config_file = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir) model_file = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", cache_dir=cache_dir) # Read config config = None if os.path.exists(config_file): with open(config_file, "r", encoding="utf-8") as f: config = json.load(f) # Create model model = cls(config) # Load weights if os.path.exists(model_file): state_dict = torch.load(model_file, map_location=torch.device(map_location)) model.load_state_dict(state_dict, strict=strict) model.eval() else: raise FileNotFoundError(f"Model weights not found at {model_file}") return model def push_to_hub(self, repo_id, token=None, commit_message=None, private=False): """Push model and config to Hugging Face Hub. Args: repo_id (str): Repository ID on HuggingFace Hub (e.g., 'username/model-name') token (str, optional): HuggingFace token. If None, will use token from ~/.huggingface/token commit_message (str, optional): Commit message for the push private (bool, optional): Whether to make the repository private """ # Create repo if it doesn't exist api = HfApi() try: create_repo(repo_id, private=private, token=token, exist_ok=True) except Exception as e: print(f"Repository creation failed: {e}") return # Save config config = cfg2dict(self.config) with open("config.json", "w", encoding="utf-8") as f: json.dump(config, f, indent=2, sort_keys=True) # Save model weights torch.save(self.cpu().state_dict(), "pytorch_model.bin") self.to(self.device if hasattr(self, 'device') else 'cuda' if torch.cuda.is_available() else 'cpu') # restore device # Push files to hub files_to_push = ["config.json", "pytorch_model.bin"] for file in files_to_push: api.upload_file( path_or_fileobj=file, path_in_repo=file, repo_id=repo_id, token=token, commit_message=commit_message or f"Upload {file}" ) os.remove(file) # Clean up local files def save_pretrained(self, save_directory: str, **kwargs): """Save model weights and configuration to a directory. Args: save_directory (str): Directory to save files in **kwargs: Additional arguments passed to save functions """ os.makedirs(save_directory, exist_ok=True) # Save config config = cfg2dict(self.config) config_file = os.path.join(save_directory, "config.json") with open(config_file, "w", encoding="utf-8") as f: json.dump(config, f, indent=2, sort_keys=True) # Save model weights model_file = os.path.join(save_directory, "pytorch_model.bin") torch.save(self.cpu().state_dict(), model_file) self.to(self.device if hasattr(self, 'device') else 'cuda' if torch.cuda.is_available() else 'cpu') # restore device