์ถํ ViT๋ชจ๋ธ๋ก ๊ต์ฒด๋ ์์ ์ ๋๋ค.
https://huggingface.co/gihakkk/vit_modle
์ ๋ชจ๋ธ๋ก ๋์ฒด๋์์ต๋๋ค
์ ๋ชจ๋ธ์ด ์ฑ๋ฅ ๋ฉด์์ ํจ์ ์๋์ ์
๋๋ค.
๋ฐ๋ผ์, ๋์ด์ CNN๋ชจ๋ธ์ ์ฌ์ฉํ์ง ์์ต๋๋ค.
ํ๊น
ํ์ด์ค์์ ๋ฐ์ ์ฌ์ฉํ ์ ์๋๋ก config ํ์ผ์ ๋ง๋ค์์ผ๋ฉฐ, ์ ์๋ํ๋๊ฒ์ ํ์ธํ์ต๋๋ค.
๋ก๋งจ์ค ์ค์บ ์์ ์นด๋ฉ๋ผ์ ์ฃผ๋ก ๋น์ถฐ์ง ์ฌ์ง์ ์กฐ์ฌํด ์ง์ด๋ฃ์ด ์ ์ฌ๋ ํ๋จ์ ์ฌ์ฉํ ์ ์๋๋ก ๊ตฌ์ฑํ์ต๋๋ค.
๋ชจ๋ธ ํ์ต์ ํ์ฉ๋ ๋ฐ์ดํฐ๋ ์ ๋ถ ์ง์ ๊ณต์ํ ๊ฒ์ด๋ฉฐ, ๋ก๋งจ์ค์ค์บ , ๋ชธ์บ ํผ์ฑ ํผํด๋ฅผ ์
์ ์ง์ธ์ ํตํด ์ป์ ๋ฐ์ดํฐ์, ๊ทธ ํน์ง์ ํตํด ์ถ๊ฐ๋ก ๋ง๋ ๋ฐ์ดํฐ๋ฅผ ํตํด ๋ง๋ค์์ต๋๋ค.
11์ ์ด์ ์
๋ก๋ ํ ์๋ก์ด ๋ฒ์ ผ์ ๋ฐ์ดํฐ๋ฅผ ์ป๊ธฐ ์ํด ์ง์ ๋ชธ์บ ํผ์ฑ์ ๋นํด ์๋ก์ด ๋ฐ์ดํฐ๋ฅผ ์ป๊ณ , ์ด๋ฅผ ๋ชจ๋ธํ์ตํ์ฌ ์ฌ๋ฆด๊ฒ ์
๋๋ค.
๋ค์๊ณผ ๊ฐ์ด ์ฌ์ฉํ ์ ์์ต๋๋ค
import tensorflow as tf
import numpy as np
from PIL import Image
import requests
# CNN ๋ชจ๋ธ ๋ค์ด๋ก๋ ๋ฐ ๋ก๋
model_url = "https://huggingface.co/gihakkk/CNN_modle/resolve/main/cnn_similarity_model.keras"
model_path = "cnn_similarity_model.keras"
# ๋ชจ๋ธ ํ์ผ ๋ค์ด๋ก๋
response = requests.get(model_url)
with open(model_path, "wb") as f:
f.write(response.content)
# Keras ๋ชจ๋ธ ๋ก๋
cnn_model = tf.keras.models.load_model(model_path)
# ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ ํจ์
def preprocess_image(image_path):
try:
img = Image.open(image_path).convert('RGB')
img = img.resize((152, 152)) # ๋ชจ๋ธ์ด ์๊ตฌํ๋ ํฌ๊ธฐ
img_array = np.array(img) / 255.0 # ์ด๋ฏธ์ง๋ฅผ 0-1 ์ฌ์ด๋ก ์ ๊ทํ
img_array = np.expand_dims(img_array, axis=0) # ๋ฐฐ์น ์ฐจ์ ์ถ๊ฐ
return img_array
except Exception as e:
print(f"Error processing image: {e}")
return None
# ์ ์ฌ๋ ์์ธก ํจ์
def predict_similarity(image_path):
img_array = preprocess_image(image_path)
if img_array is not None:
predictions = cnn_model.predict(img_array) # ๋ชจ๋ธ์ ํตํด ์์ธก
similarity_score = np.mean(predictions) # ์ ์ฌ๋ ์ ์์ ํ๊ท ๊ณ์ฐ
if similarity_score > 0.5: # ์๊ณ๊ฐ์ ๊ธฐ์ค์ผ๋ก ์ ์ฌ๋ ํ๋จ
return "๋ก๋งจ์ค ์ค์บ ์ด๋ฏธ์ง์
๋๋ค."
else:
return "๋ก๋งจ์ค ์ค์บ ์ด๋ฏธ์ง๊ฐ ์๋๋๋ค."
else:
return "์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ์ ์คํจํ์ต๋๋ค."
# ํ
์คํธ ์ด๋ฏธ์ง ์์ธก
image_path = r'์ฌ์ง ์์น ์
๋ ฅ' # ํ
์คํธํ ์ด๋ฏธ์ง ๊ฒฝ๋ก
result = predict_similarity(image_path)
print(result)
- Downloads last month
- 95