Spaces:
Paused
Paused
import logging
from typing import Optional

from fastapi import FastAPI
from pydantic import BaseModel
from vllm import LLM, SamplingParams, RequestOutput
# Don't forget to set HF_TOKEN in the environment before running.
app = FastAPI()

# Build the vLLM engine once, at import time. The model and its revision are
# pinned so every start loads the same weights.
# NOTE(review): max_model_len (131072) is far larger than
# max_num_batched_tokens (512). Some vLLM configurations reject that
# combination unless chunked prefill is enabled, and a KV cache sized for
# 131072 tokens may not fit on a T4 -- confirm the engine actually starts
# with these values, or lower max_model_len.
engine = LLM(
    model='meta-llama/Llama-3.2-3B-Instruct',
    revision="0cb88a4f764b7a12671c53f0838cd831a0843b95",
    max_num_batched_tokens=512,  # Reduced for T4
    max_num_seqs=16,  # Reduced for T4
    gpu_memory_utilization=0.85,  # Slightly increased, adjust if needed
    max_model_len=131072,  # Llama-3.2-3B-Instruct context length
    enforce_eager=True,  # Disable CUDA graph
    dtype='half',  # Use half precision
)
# NOTE(review): no route decorator is visible on this function in this view;
# confirm it is registered with `app` if it is meant to be an endpoint.
def greet_json() -> dict[str, str]:
    """Return a fixed ``{"Hello": "World!"}`` payload."""
    return {"Hello": "World!"}
class GenerationRequest(BaseModel):
    """Input schema for a text-generation call."""

    # Prompt text handed to the engine.
    prompt: str
    # Forwarded to SamplingParams(max_tokens=...).
    max_tokens: int = 100
    # Forwarded to SamplingParams(temperature=...).
    temperature: float = 0.7
    # Forwarded to SamplingParams(logit_bias=...); presumably maps token id
    # to bias value -- verify against the vLLM version in use.
    logit_bias: Optional[dict[int, float]] = None
class GenerationResponse(BaseModel):
    """Response schema carrying an optional ``text`` and an optional ``error``.

    Both fields default to ``None`` explicitly. Pydantic v2 treats a bare
    ``Optional[str]`` annotation as required-but-nullable, so without the
    defaults an instance could not be built from only one of the two fields.
    """

    text: Optional[str] = None
    error: Optional[str] = None
# NOTE(review): no route decorator is visible on this function in this view;
# confirm it is registered with `app` if it is meant to be an endpoint.
def generate_text(request: GenerationRequest) -> list[RequestOutput] | dict[str, str]:
    """Run one generation request through the module-level vLLM engine.

    Args:
        request: Prompt plus sampling settings (``max_tokens``,
            ``temperature`` and optional ``logit_bias``).

    Returns:
        The value returned by ``engine.generate`` on success, or
        ``{"error": <message>}`` when building the sampling parameters or
        generating raised an exception.
    """
    try:
        sampling_params = SamplingParams(
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            logit_bias=request.logit_bias,
        )
        return engine.generate(
            prompts=request.prompt,
            sampling_params=sampling_params,
        )
    except Exception as e:
        # Failures are reported in-band rather than raised. Log the traceback
        # first so the cause is not lost on the server side.
        logging.getLogger(__name__).exception("Text generation failed")
        return {"error": str(e)}