import gradio from peft import PeftModel, PeftConfig import re import torch from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig import re import torch from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig from peft import PeftModel, PeftConfig # peft_model_id = "/home/afsd721/komt/output_total/checkpoint-1000" ## peft_model_id = "afsd721/onedoit" peft_config = PeftConfig.from_pretrained(peft_model_id) bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16 ) tokenizer = AutoTokenizer.from_pretrained(peft_model_id, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(peft_config.base_model_name_or_path, quantization_config=bnb_config) model = PeftModel.from_pretrained(model, peft_model_id) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device=device) model.eval() def preprocessing(text): # 문제를 일으킬 수 있는 문자 제거 bad_chars = {"\u200b": "", "…": " ... ", "\ufeff": ""} for bad_char in bad_chars: text = text.replace(bad_char, bad_chars[bad_char]) error_chars = {"\u3000": " ", "\u2009": " ", "\u2002": " ", "\xa0":" "} for error_char in error_chars: text = text.replace(error_char, error_chars[error_char]) # 이메일 제거 text = re.sub(r"[a-zA-Z0-9+-_.]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+", "[이메일]", text).strip() # "#문자" 형식 어절 제거 text = re.sub(r"#\S+", "", text).strip() # "@문자" 형식 어절 제거 text = re.sub(r"@\w+", "", text).strip() # URL 제거 text = re.sub(r"(http|https)?:\/\/\S+\b|www\.(\w+\.)+\S*", "[웹주소]", text).strip() text = re.sub(r"pic\.(\w+\.)+\S*", "[웹주소]", text).strip() # 뉴스 저작권 관련 텍스트 제거 re_patterns = [ r"\<저작권자(\(c\)|ⓒ|©|\(Copyright\)|(\(c\))|(\(C\))).+?\>", r"저작권자\(c\)|ⓒ|©|(Copyright)|(\(c\))|(\(C\))" ] for re_pattern in re_patterns: text = re.sub(re_pattern, "", text).strip() # 뉴스 내 포함된 이미지에 대한 레이블 제거 text = re.sub(r"\(출처 ?= ?.+\) |\(사진 ?= ?.+\) |\(자료 ?= ?.+\)| \(자료사진\) |사진=.+기자 ", "", text).strip() # 문제를 일으킬 수 있는 구두점 치환 punct_mapping = {"‘": "'", "₹": "e", "´": "'", "°": "", "€": "e", "™": "tm", "√": " sqrt ", "×": "x", "²": "2", "—": "-", "–": "-", "’": "'", "_": "-", "`": "'", '“': '"', '”': '"', '“': '"', "£": "e", '∞': 'infinity', 'θ': 'theta', '÷': '/', 'α': 'alpha', '•': '.', 'à': 'a', '−': '-', 'β': 'beta', '∅': '', '³': '3', 'π': 'pi', } for p in punct_mapping: text = text.replace(p, punct_mapping[p]) # 연속된 공백 치환 text = re.sub(r"\s+", " ", text).strip() # 개행 문자 "\n" 제거 text = text.replace('\n', '') # 기타 태그 제거 text = re.sub('<.+?>', '', text, 0, re.I|re.S) return text def my_inference_function(input_text): input_text = preprocessing(input_text) generation_config = GenerationConfig( temperature=0.8, top_p=0.8, top_k=100, max_new_tokens=512, early_stopping=True, do_sample=True, ) q = f"### instruction: {input_text}\n\n### Response: " gened = model.generate( **tokenizer( q, return_tensors='pt', return_token_type_ids=False ).to(device), generation_config=generation_config, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, # streamer=streamer, ) result_str = tokenizer.decode(gened[0]) start_tag = f"\n\n### Response: " start_index = result_str.find(start_tag) if start_index != -1: result_str = result_str[start_index + len(start_tag):].strip() result_str = preprocessing(result_str) return result_str gradio_interface = gradio.Interface( fn = my_inference_function, inputs = "text", outputs = "text" ) gradio_interface.launch()