import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import spaces  # Hugging Face Spaces helper (ZeroGPU); not used directly below
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict, Any
# Create the FastAPI app
app = FastAPI()

# Configure CORS (wide open, so any frontend can call the API)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the model and tokenizer
model_name = "BAAI/bge-m3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()
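# The model stays on CPU as loaded above. On a GPU Space it could be moved to
# CUDA instead (a minimal sketch, assuming a GPU is actually available):
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = model.to(device)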
class EmbeddingRequest(BaseModel):
    input: List[str] | str
    model: str | None = model_name
    encoding_format: str | None = "float"
    user: str | None = None

class EmbeddingResponse(BaseModel):
    object: str = "list"
    data: List[Dict[str, Any]]
    model: str
    usage: Dict[str, int]
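# For reference, the response mirrors OpenAI's /v1/embeddings payload shape,
# e.g. (values illustrative only):
#
#   {
#       "object": "list",
#       "data": [{"object": "embedding", "embedding": [0.01, ...], "index": 0}],
#       "model": "BAAI/bge-m3",
#       "usage": {"prompt_tokens": 5, "total_tokens": 5},
#   }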
def get_embedding(text: str) -> List[float]:
    """Encode a single text into a dense vector via CLS-token pooling."""
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    ).to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    # BGE models use the [CLS] token (position 0) as the sentence embedding.
    # Note: the vector is returned unnormalized; BGE embeddings are typically
    # L2-normalized before similarity search.
    embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
    return embeddings[0].tolist()
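# get_embedding() runs one forward pass per text. A batched variant would
# tokenize the whole list at once (a sketch; `get_embeddings_batch` is a
# hypothetical helper, not part of this app):
#
#   def get_embeddings_batch(texts: List[str]) -> List[List[float]]:
#       inputs = tokenizer(texts, padding=True, truncation=True,
#                          max_length=512, return_tensors="pt").to(model.device)
#       with torch.no_grad():
#           outputs = model(**inputs)
#       return outputs.last_hidden_state[:, 0, :].cpu().numpy().tolist()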
def process_embeddings(request_dict: dict) -> dict:
    """Synchronous helper that computes embeddings for one or more texts."""
    input_texts = [request_dict["input"]] if isinstance(request_dict["input"], str) else request_dict["input"]
    embeddings = []
    total_tokens = 0
    for text in input_texts:
        # Token count is taken before truncation, so it may exceed the
        # 512-token window actually embedded by get_embedding()
        tokens = tokenizer.encode(text)
        total_tokens += len(tokens)
        embedding = get_embedding(text)
        embeddings.append({
            "object": "embedding",
            "embedding": embedding,
            "index": len(embeddings),
        })
    return {
        "object": "list",
        "data": embeddings,
        "model": request_dict.get("model", model_name),
        "usage": {
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
        },
    }
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """Async API endpoint, registered on the OpenAI-compatible path."""
    result = process_embeddings(request.dict())
    return result
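# Example request against the endpoint above (assumes the app is running
# locally on port 7860, as in the __main__ block at the bottom):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/v1/embeddings",
#       json={"input": ["hello world"], "model": "BAAI/bge-m3"},
#   )
#   print(resp.json()["usage"])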
def gradio_embedding(text: str) -> Dict:
    """Gradio interface function."""
    request_dict = {
        "input": text,
        "model": model_name,
    }
    return process_embeddings(request_dict)
# Create the Gradio UI
demo = gr.Interface(
    fn=gradio_embedding,
    inputs=gr.Textbox(lines=3, placeholder="Enter the text to encode..."),
    outputs=gr.JSON(),
    title="BGE-M3 Embeddings (OpenAI-compatible format)",
    description="Enter text to get its embedding vector, returned in an OpenAI-API-compatible format.",
    examples=[
        ["This is a sample text."],
        ["Artificial intelligence is changing the world."],
    ],
)
# Mount the Gradio app onto FastAPI
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
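# End-to-end usage via the official OpenAI client (a sketch; assumes the app
# is reachable at localhost:7860 and the `openai` package is installed, which
# is not a dependency of this app):
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:7860/v1", api_key="not-needed")
#   resp = client.embeddings.create(model="BAAI/bge-m3", input="hello world")
#   print(len(resp.data[0].embedding))  # BGE-M3 dense vectors have 1024 dims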