import functools
import json
import os

import gradio as gr
import torch
from TeLVE.imagine import ImageCaptioningModel, load_model, generate_caption
from transformers import BertTokenizerFast
from huggingface_hub import hf_hub_download, list_repo_files

# Constants
MODELS_DIR = "./TeLVE/models"
TOKENIZER_PATH = "./TeLVE/tokenizer"
HF_REPO_ID = "outsu/TeLVE"
MODEL_STATE_FILE = "./TeLVE/model_state.json"
# Example images offered in the UI; only the ones present on disk are used.
EXAMPLE_IMAGES = ["./images/mugla.jpg", "./images/example.jpg"]
HF_MODELS_PREFIX = "models/"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model_state():
    """Load the model download state from MODEL_STATE_FILE.

    Returns:
        A dict that always contains a "downloaded_models" list. A missing,
        unreadable or malformed state file yields an empty state instead of
        raising, because the state only mirrors what is on disk and
        `verify_model_integrity` / `download_missing_models` rebuild it.
    """
    default_state = {"downloaded_models": []}
    if not os.path.exists(MODEL_STATE_FILE):
        return default_state
    try:
        with open(MODEL_STATE_FILE, "r", encoding="utf-8") as f:
            state = json.load(f)
    except (OSError, ValueError) as e:  # JSONDecodeError is a ValueError
        print(f"Warning: could not read {MODEL_STATE_FILE} ({e}); using empty state.")
        return default_state
    if not isinstance(state, dict):
        return default_state
    if not isinstance(state.get("downloaded_models"), list):
        # Callers index state["downloaded_models"] directly; guarantee the key.
        state["downloaded_models"] = []
    return state


def save_model_state(state):
    """Persist `state` as JSON to MODEL_STATE_FILE, creating its directory if needed."""
    os.makedirs(os.path.dirname(MODEL_STATE_FILE), exist_ok=True)
    with open(MODEL_STATE_FILE, "w", encoding="utf-8") as f:
        json.dump(state, f, indent=2)


def list_available_models():
    """Return the names of all local `.pth` files in MODELS_DIR, reverse-sorted.

    Returns an empty list when MODELS_DIR does not exist.
    """
    if not os.path.exists(MODELS_DIR):
        return []
    models = [f for f in os.listdir(MODELS_DIR) if f.endswith(".pth")]
    return sorted(models, reverse=True)


def get_hf_model_list():
    """Return `.pth` file names directly under `models/` in the HF repo, reverse-sorted.

    Only files one level below `models/` are listed, because
    `download_missing_models` fetches them as `models/{name}`. Network or Hub
    errors are reported and produce an empty list (best effort).
    """
    try:
        files = list_repo_files(HF_REPO_ID)
    except Exception as e:
        print(f"Error fetching models from HuggingFace: {str(e)}")
        return []
    models = []
    for path in files:
        if not (path.startswith(HF_MODELS_PREFIX) and path.endswith(".pth")):
            continue
        name = path[len(HF_MODELS_PREFIX):]
        if "/" in name:
            # Nested files could not be fetched via f"models/{name}" below.
            continue
        models.append(name)
    return sorted(models, reverse=True)


def download_missing_models():
    """Download every Hub model that is not present locally, then record it in the state file.

    Downloads are best effort: a failure is printed and the remaining models
    are still attempted.
    """
    os.makedirs(MODELS_DIR, exist_ok=True)

    state = load_model_state()
    downloaded_models = set(state["downloaded_models"])
    local_models = set(list_available_models())
    hf_models = set(get_hf_model_list())

    # Decide from what is actually on disk: a state entry for a file that was
    # deleted locally must not prevent it from being re-downloaded.
    models_to_download = sorted(hf_models - local_models, reverse=True)

    for model in models_to_download:
        try:
            print(f"Downloading missing model: {model}")
            hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=f"{HF_MODELS_PREFIX}{model}",
                # The repo path already contains "models/", so download into
                # the parent of MODELS_DIR to land at MODELS_DIR/<model>.
                local_dir=os.path.dirname(MODELS_DIR),
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            print(f"Error downloading {model}: {str(e)}")
            continue
        downloaded_models.add(model)

    # Update state with newly downloaded models
    state["downloaded_models"] = sorted(downloaded_models)
    save_model_state(state)


def verify_model_integrity():
    """Drop state entries whose model file no longer exists in MODELS_DIR."""
    state = load_model_state()
    local_models = set(list_available_models())
    state["downloaded_models"] = sorted(set(state["downloaded_models"]) & local_models)
    save_model_state(state)


@functools.lru_cache(maxsize=1)
def _load_cached_model(model_path):
    """Load a checkpoint, keeping only the most recently used model in memory."""
    return load_model(model_path)


@functools.lru_cache(maxsize=1)
def _load_tokenizer():
    """Load the tokenizer from TOKENIZER_PATH once per process."""
    return BertTokenizerFast.from_pretrained(TOKENIZER_PATH)


def generate_description(image, model_name):
    """Generate a caption for `image` using the model file `model_name`.

    Args:
        image: Image input from the Gradio component (a file path, since the
            component uses type="filepath"); None when nothing was uploaded.
        model_name: File name of a `.pth` model inside MODELS_DIR.

    Returns:
        The generated caption, or a human-readable "Error..." string. Errors
        are returned rather than raised so they are shown in the output box.
    """
    if image is None:
        return "Error: Please upload an image."
    # Only plain file names are accepted so the path cannot leave MODELS_DIR.
    if not model_name or os.path.basename(model_name) != model_name:
        return "Error: Selected model file not found."
    try:
        model_path = os.path.join(MODELS_DIR, model_name)
        if not os.path.exists(model_path):
            return "Error: Selected model file not found."
        if not os.path.exists(TOKENIZER_PATH):
            return "Error: Tokenizer not found. Please make sure you have trained a model first."

        # Cached: reloading weights from disk on every request is expensive.
        model = _load_cached_model(model_path)
        tokenizer = _load_tokenizer()

        return generate_caption(model, image, tokenizer)
    except Exception as e:
        return f"Error occurred: {str(e)}"


def create_interface():
    """Build the Gradio interface.

    If no local models exist, this returns a placeholder interface that only
    reports the problem. Otherwise it returns the captioning UI, with a model
    dropdown that defaults to the first name from `list_available_models`.
    """
    available_models = list_available_models()
    if not available_models:
        return gr.Interface(
            fn=lambda x: f"No models found in {MODELS_DIR} directory. Please train a model first.",
            inputs="image",
            outputs="text",
            title="TeLVE - Turkish efficient Language Vision Engine 🧿",
            description="Error: No models available",
        )

    default_model = available_models[0]
    # Offer only example files that actually exist; a missing example file
    # would otherwise break the interface at startup.
    examples = [[path, default_model] for path in EXAMPLE_IMAGES if os.path.isfile(path)]

    interface = gr.Interface(
        fn=generate_description,
        inputs=[
            gr.Image(type="filepath", label="Upload Image"),
            gr.Dropdown(choices=available_models, label="Select Model", value=default_model),
        ],
        outputs=gr.Textbox(label="Generated Caption"),
        title="TeLVE - Turkish efficient Language Vision Engine 🧿",
        description="Upload an image to generate a Turkish description.",
        examples=examples or None,
    )
    return interface


if __name__ == "__main__":
    print("Verifying model integrity...")
    verify_model_integrity()

    print("Checking for missing models...")
    download_missing_models()

    demo = create_interface()
    demo.launch(share=True, server_name="0.0.0.0")