Hjgugugjhuhjggg committed on
Commit
9de7b93
·
verified ·
1 Parent(s): b5fcdec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -7
app.py CHANGED
@@ -1,12 +1,12 @@
1
  import os
2
  import torch
3
- from fastapi import FastAPI
4
  from fastapi.responses import StreamingResponse
5
  from pydantic import BaseModel, field_validator
6
  from transformers import (
7
  AutoConfig,
8
  pipeline,
9
- AutoModelForCausalLM,
10
  AutoTokenizer,
11
  GenerationConfig,
12
  StoppingCriteriaList
@@ -69,7 +69,7 @@ class S3ModelLoader:
69
  s3_uri = self._get_s3_uri(model_name)
70
  try:
71
  config = AutoConfig.from_pretrained(s3_uri, local_files_only=True)
72
- model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=True)
73
  tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=True)
74
 
75
  if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
@@ -80,7 +80,7 @@ class S3ModelLoader:
80
  try:
81
  config = AutoConfig.from_pretrained(model_name)
82
  tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
83
- model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
84
 
85
  if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
86
  tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
@@ -135,7 +135,6 @@ async def generate(request: GenerateRequest):
135
  raise HTTPException(status_code=500,
136
  detail=f"Internal server error: {str(e)}")
137
 
138
-
139
  async def stream_text(model, tokenizer, input_text,
140
  generation_config, stop_sequences,
141
  device, chunk_delay, max_length=2048):
@@ -199,8 +198,6 @@ async def stream_text(model, tokenizer, input_text,
199
  truncation=True,
200
  max_length=max_length).to(device)
201
 
202
-
203
-
204
  @app.post("/generate-image")
205
  async def generate_image(request: GenerateRequest):
206
  try:
 
1
  import os
2
  import torch
3
+ from fastapi import FastAPI, HTTPException
4
  from fastapi.responses import StreamingResponse
5
  from pydantic import BaseModel, field_validator
6
  from transformers import (
7
  AutoConfig,
8
  pipeline,
9
+ AutoModelForSeq2SeqLM, # Changed AutoModelForCausalLM to AutoModelForSeq2SeqLM
10
  AutoTokenizer,
11
  GenerationConfig,
12
  StoppingCriteriaList
 
69
  s3_uri = self._get_s3_uri(model_name)
70
  try:
71
  config = AutoConfig.from_pretrained(s3_uri, local_files_only=True)
72
+ model = AutoModelForSeq2SeqLM.from_pretrained(s3_uri, config=config, local_files_only=True) # Changed AutoModelForCausalLM
73
  tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=True)
74
 
75
  if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
 
80
  try:
81
  config = AutoConfig.from_pretrained(model_name)
82
  tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
83
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config) # Changed AutoModelForCausalLM
84
 
85
  if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
86
  tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
 
135
  raise HTTPException(status_code=500,
136
  detail=f"Internal server error: {str(e)}")
137
 
 
138
  async def stream_text(model, tokenizer, input_text,
139
  generation_config, stop_sequences,
140
  device, chunk_delay, max_length=2048):
 
198
  truncation=True,
199
  max_length=max_length).to(device)
200
 
 
 
201
  @app.post("/generate-image")
202
  async def generate_image(request: GenerateRequest):
203
  try: