Update app.py
app.py CHANGED
@@ -1,12 +1,12 @@
 import os
 import torch
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, field_validator
 from transformers import (
     AutoConfig,
     pipeline,
-    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,  # Changed AutoModelForCausalLM to AutoModelForSeq2SeqLM
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteriaList
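Note: the added HTTPException import is load-bearing, not cosmetic. The error handler later in this file (see the @@ -135 hunk below) already raises HTTPException, so without the import that raise would itself fail with a NameError instead of returning a clean 500. A minimal sketch of the pattern, with a hypothetical route name:

    from fastapi import FastAPI, HTTPException

    app = FastAPI()

    @app.post("/generate")  # hypothetical route for illustration
    async def generate():
        try:
            ...  # model inference would go here
        except Exception as e:
            # With HTTPException imported, failures surface as a clean
            # HTTP 500 instead of a NameError inside the except block.
            raise HTTPException(status_code=500,
                                detail=f"Internal server error: {str(e)}")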
@@ -69,7 +69,7 @@ class S3ModelLoader:
         s3_uri = self._get_s3_uri(model_name)
         try:
             config = AutoConfig.from_pretrained(s3_uri, local_files_only=True)
-            model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=True)
+            model = AutoModelForSeq2SeqLM.from_pretrained(s3_uri, config=config, local_files_only=True)  # Changed AutoModelForCausalLM
             tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=True)
 
             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
@@ -80,7 +80,7 @@ class S3ModelLoader:
         try:
             config = AutoConfig.from_pretrained(model_name)
             tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
-            model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
+            model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)  # Changed AutoModelForCausalLM
 
             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
                 tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
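Note: both loader paths (S3 and Hub) now instantiate the model through AutoModelForSeq2SeqLM, which only accepts encoder-decoder checkpoints such as T5, BART, or Marian; a decoder-only checkpoint like GPT-2 will raise a ValueError under this class. The pad-token fallback the loader keeps is mostly a causal-LM concern, since most seq2seq tokenizers already ship a pad token. A minimal sketch of the seq2seq flow, using t5-small purely as an illustrative checkpoint:

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    model_name = "t5-small"  # illustrative encoder-decoder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # A seq2seq model encodes the input and decodes a fresh sequence,
    # rather than continuing the prompt the way a causal LM does.
    inputs = tokenizer("translate English to German: Hello", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))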
@@ -135,7 +135,6 @@ async def generate(request: GenerateRequest):
         raise HTTPException(status_code=500,
                             detail=f"Internal server error: {str(e)}")
 
-
 async def stream_text(model, tokenizer, input_text,
                       generation_config, stop_sequences,
                       device, chunk_delay, max_length=2048):
@@ -199,8 +198,6 @@ async def stream_text(model, tokenizer, input_text,
                         truncation=True,
                         max_length=max_length).to(device)
 
-
-
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
     try:
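Note: the last two hunks only strip blank lines around stream_text and the /generate-image route, but they point at the streaming machinery this file relies on: stream_text takes a chunk_delay and is presumably consumed by the StreamingResponse imported at the top. A minimal sketch of that pairing, with hypothetical chunking rather than the app's actual token streaming:

    import asyncio
    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse

    app = FastAPI()

    async def stream_chunks(text: str, chunk_delay: float = 0.05):
        # Hypothetical stand-in for stream_text: yield output piece by
        # piece so the client sees text as it is produced.
        for word in text.split():
            yield word + " "
            await asyncio.sleep(chunk_delay)

    @app.get("/demo-stream")  # hypothetical route for illustration
    async def demo_stream():
        return StreamingResponse(stream_chunks("streamed model output"),
                                 media_type="text/plain")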