from fastapi import FastAPI, Request, Depends
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

app = FastAPI()

MODEL_ID = "152334H/miqu-1-70b-sf"

# Initialize device
def initialize_device():
    global device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

initialize_device()

# Helper dependency to read the raw request body as bytes
async def parse_raw(request: Request) -> bytes:
    return await request.body()

# Initialize the model and tokenizer with the corrected pre-trained weights
def init_corrected_model():
    global model_config, model, tokenizer
    model_config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    try:
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, config=model_config).to(device)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    except Exception as e:
        print("[WARNING]: Failed to load model and tokenizer conventionally.")
        print(f"Exception: {e}")
        # Fallback: retry with trust_remote_code=True in case the repo ships custom code.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, config=model_config, trust_remote_code=True
        ).to(device)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

init_corrected_model()

# Utility function to generate an answer from the model
def miku_answer(query: str) -> dict[str, str]:
    query_tokens = tokenizer.encode(query, return_tensors="pt").to(device)
    # LLaMA-style tokenizers often have no pad token; fall back to the EOS token.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    answer = model.generate(query_tokens, max_length=128, pad_token_id=pad_id)
    # Decode the first generated sequence, not the first token of every sequence.
    return {"output": tokenizer.decode(answer[0], skip_special_tokens=True)}

# Endpoint handler: receive an incoming query and pass it to the utility function
@app.post("/infer_miku")
async def infer_endpoint(data: bytes = Depends(parse_raw)):
    input_text = data.decode("utf-8")
    if not input_text:
        return JSONResponse({"error": "Empty input received."}, status_code=400)
    return miku_answer(input_text)

@app.get("/infer_miku")
def get_default_inference_endpoint():
    return {"message": "Use POST method to submit input data"}

# Mount static files
app.mount("/static", StaticFiles(directory="static"), name="static")

# Initialization done
print("Initialization Complete.")
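
# Usage sketch (assumptions: this file is saved as app.py and served with
# uvicorn on localhost:8000; the filename and port are illustrative, not
# prescribed by the code above):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#   curl -X POST http://localhost:8000/infer_miku \
#        -H "Content-Type: text/plain" \
#        --data "What is the capital of France?"
#
# The endpoint reads the raw request body as UTF-8 text and responds with a
# JSON object of the form {"output": "..."}.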