Tìm kiếm ngữ nghĩa với FAISS
Trong phần 5, chúng ta đã tạo tập dữ liệu về các vấn đề GitHub và nhận xét từ kho lưu trữ 🤗 Datasets. Trong phần này, chúng ta sẽ sử dụng thông tin này để xây dựng một công cụ tìm kiếm có thể giúp ta tìm câu trả lời cho những câu hỏi cấp bách nhất về thư viện!
Sử dụng nhúng biểu diễn từ cho tìm kiếm ngữ nghĩa
Như chúng ta đã thấy trong Chương 1, các mô hình ngôn ngữ dựa trên Transformer đại diện cho mỗi token trong một khoảng văn bản dưới dạng một vector nhugns biểu diễn từ. Hóa ra người ta có thể “gộp” các biểu diễn riêng lẻ để tạo biểu diễn vectơ cho toàn bộ câu, đoạn văn hoặc toàn bộ (trong một số trường hợp) tài liệu. Sau đó, các phép nhúng này có thể được sử dụng để tìm các tài liệu tương tự trong kho tài liệu bằng cách tính toán độ tương tự của sản phẩm (hoặc một số chỉ số tương tự khác) giữa mỗi biễu diễn và trả về các tài liệu có độ tương đồng lớn nhất.
Trong phần này, chúng ta sẽ sử dụng các biểu diễn từ để phát triển một công cụ tìm kiếm ngữ nghĩa. Các công cụ tìm kiếm này cung cấp một số lợi thế so với các phương pháp tiếp cận thông thường dựa trên việc kết hợp các từ khóa trong một truy vấn với các tài liệu.
Tải và chuẩn bị tập dữ liệu
Điều đầu tiên chúng ta cần làm là tải xuống tập dữ liệu về các sự cố GitHub, vì vậy hãy sử dụng thư viện 🤗 Hub để giải quyết URL nơi tệp của chúng ta được lưu trữ trên Hugging Face Hub:
from huggingface_hub import hf_hub_url
data_files = hf_hub_url(
repo_id="lewtun/github-issues",
filename="datasets-issues-with-comments.jsonl",
repo_type="dataset",
)
Với URL được lưu trữ trong data_files
, sau đó chúng ta có thể tải tập dữ liệu từ xa bằng phương pháp đã được giới thiệu trong phần 2:
from datasets import load_dataset
issues_dataset = load_dataset("json", data_files=data_files, split="train")
issues_dataset
Dataset({
features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
num_rows: 2855
})
Ở đây chúng ta đã chỉ định tách train
mặc định trong load_dataset()
, vì vậy nó trả về một Dataset
thay vì DatasetDict
. Trình tự đầu tiên của doanh nghiệp là lọc ra các yêu cầu kéo, vì những yêu cầu này có xu hướng hiếm khi được sử dụng để trả lời các truy vấn của người dùng và sẽ tạo ra nhiễu trong công cụ tìm kiếm mình. Chúng ta có thể sử dụng hàm Dataset.filter()
đã quen thuộc với bạn để loại trừ các hàng này trong tập dữ liệu của mình. Cùng lúc đó, hãy cùng lọc ra các hàng không có nhận xét, vì những hàng này không cung cấp câu trả lời cho các truy vấn của người dùng:
issues_dataset = issues_dataset.filter(
lambda x: (x["is_pull_request"] == False and len(x["comments"]) > 0)
)
issues_dataset
Dataset({
features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
num_rows: 771
})
Chúng ta có thể thấy rằng có rất nhiều cột trong tập dữ liệu của chúng ta, hầu hết trong số đó chúng ta không cần phải xây dựng công cụ tìm kiếm của mình. Từ góc độ tìm kiếm, các cột chứa nhiều thông tin nhất là title
, body
, và comments
, trong khi html_url
cung cấp cho chúng ta một liên kết trỏ về nguồn. Hãy sử dụng hàm Dataset.remove_columns()
để xóa phần còn lại:
columns = issues_dataset.column_names
columns_to_keep = ["title", "body", "html_url", "comments"]
columns_to_remove = set(columns_to_keep).symmetric_difference(columns)
issues_dataset = issues_dataset.remove_columns(columns_to_remove)
issues_dataset
Dataset({
features: ['html_url', 'title', 'comments', 'body'],
num_rows: 771
})
Để tạo các biểu diễn, chúng ta sẽ bổ sung mỗi nhận xét với tiêu đề và nội dung của vấn đề, vì các trường này thường bao gồm thông tin ngữ cảnh hữu ích. Vì cột comments
của hiện là danh sách các nhận xét cho từng vấn đề, chúng tôi cần khám phá các cột để mỗi hàng bao gồm một tuple (html_url, title, body, comment)
. Trong Pandas, chúng ta có thể thực hiện việc này bằng hàm hàm DataFrame.explode()
, tạo một hàng mới cho mỗi phần tử trong cột giống như danh sách, trong khi sao chép tất cả các giá trị cột khác. Để xem điều này hoạt động, trước tiên chúng ta hãy chúng chuyển thành định dạng Pandas DataFrame
:
issues_dataset.set_format("pandas")
df = issues_dataset[:]
Nếu ta kiểm tra hàng đầu tiên trong DataFrame
này, chúng ta có thể thấy có bốn nhận xét liên quan đến vấn đề này:
df["comments"][0].tolist()
['the bug code locate in :\r\n if data_args.task_name is not None:\r\n # Downloading and loading a dataset from the hub.\r\n datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)',
'Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com\r\n\r\nNormally, it should work if you wait a little and then retry.\r\n\r\nCould you please confirm if the problem persists?',
'cannot connect,even by Web browser,please check that there is some problems。',
'I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem...']
Khi chúng ta khám phá df
, chúng tôi mong đợi nhận được một hàng cho mỗi nhận xét này. Hãy kiểm tra xem nó đã đúng chưa:
comments_df = df.explode("comments", ignore_index=True)
comments_df.head(4)
html_url | title | comments | body | |
---|---|---|---|---|
0 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: Couldn't reach https://raw.githubusercontent.com | the bug code locate in :\r\n if data_args.task_name is not None... | Hello,\r\nI am trying to run run_glue.py and it gives me this error... |
1 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: Couldn't reach https://raw.githubusercontent.com | Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com... | Hello,\r\nI am trying to run run_glue.py and it gives me this error... |
2 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: Couldn't reach https://raw.githubusercontent.com | cannot connect,even by Web browser,please check that there is some problems。 | Hello,\r\nI am trying to run run_glue.py and it gives me this error... |
3 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: Couldn't reach https://raw.githubusercontent.com | I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem... | Hello,\r\nI am trying to run run_glue.py and it gives me this error... |
Tuyệt vời, chúng ta có thể thấy các hàng đã được nhân rộng, với cột comments
chứa các nhận xét riêng lẻ! Bây giờ chúng ta đã hoàn thành với Pandas, chúng ta có thể nhanh chóng chuyển trở lại Dataset
bằng cách tải DataFrame
vào bộ nhớ:
from datasets import Dataset
comments_dataset = Dataset.from_pandas(comments_df)
comments_dataset
Dataset({
features: ['html_url', 'title', 'comments', 'body'],
num_rows: 2842
})
Được rồi, điều này đã cho chúng ta vài nghìn nhận xét để làm việc cùng!
✏️ Thử nghiệm thôi! Cùng xem liệu bạn có thể sử dụng Dataset.map()
để khám phá cột comments
của issues_dataset
mà không cần sử dụng Pandas hay không. Nó sẽ hơi khó khăn một chút; bạn có thể xem phần “Batch mapping” của tài liệu 🤗 Datasets, một tài liệu hữu ích cho tác vụ này.
Bây giờ chúng ta đã có một nhận xét trên mỗi hàng, hãy tạo một cột comments_length
mới chứa số từ trên mỗi nhận xét:
comments_dataset = comments_dataset.map(
lambda x: {"comment_length": len(x["comments"].split())}
)
Chúng ta có thể sử dụng cột mới này để lọc ra các nhận xét ngắn, thường bao gồm những thứ như “cc @lewtun” hoặc “Thanks!” không liên quan đến công cụ tìm kiếm của mình. Không có con số chính xác để chọn cho bộ lọc, nhưng khoảng 15 từ có vẻ như là một khởi đầu tốt:
comments_dataset = comments_dataset.filter(lambda x: x["comment_length"] > 15)
comments_dataset
Dataset({
features: ['html_url', 'title', 'comments', 'body', 'comment_length'],
num_rows: 2098
})
Sau khi dọn dẹp tập dữ liệu một chút, hãy ghép tiêu đề, mô tả và nhận xét của vấn đề với nhau trong một cột text
mới. Như thường lệ, chúng ta sẽ viết một hàm đơn giản mà chúng ta có thể truyền vào Dataset.map()
:
def concatenate_text(examples):
return {
"text": examples["title"]
+ " \n "
+ examples["body"]
+ " \n "
+ examples["comments"]
}
comments_dataset = comments_dataset.map(concatenate_text)
Cuối cùng thì chúng ta cũng đã sẵn sàng để tạo một số biểu điễn! Chúng ta hãy xem nào.
Tạo ra biểu diễn văn bản
Chúng ta đã thấy trong Chương 2 rằng ta có thể nhận được token biễu diễn nhúng bằng cách sử dụng lớp AutoModel
. Tất cả những gì chúng ta cần làm là chọn một checkpoint phù hợp để tải mô hình từ đó. May mắn thay, có một thư viện tên là sentence-transformers
dành riêng cho việc tạo các biểu diễn này. Như được mô tả trong tài liệu của thư viện, trường hợp sử dụng của ta là một ví dụ về tìm kiếm ngữ nghĩa phi đối xứng bởi vì chúng ta có một truy vấn ngắn có câu trả lời ta muốn tìm thấy trong một tài liệu lại dài hơn nhiều, chẳng hạn như một nhận xét về vấn đề. Bảng tổng quan về mô hình trong phần tài liệu chỉ ra rằng checkpoint multi-qa-mpnet-base-dot-v1
có hiệu suất tốt nhất cho tìm kiếm ngữ nghĩa, vì vậy chúng ta sẽ sử dụng nó cho ứng dụng của mình. Chúng ta cũng sẽ tải tokenizer bằng cách sử dụng cùng một checkpoint:
from transformers import AutoTokenizer, AutoModel
model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
Để tăng tốc quá trình biểu diễn, ta sẽ giúp đặt mô hình và đầu vào trên thiết bị GPU, vì vậy hãy làm điều đó ngay bây giờ thôi:
import torch
device = torch.device("cuda")
model.to(device)
Như đã đề cập trước đó, chúng ta muốn biểu diễn mỗi mục trong kho dữ liệu các vấn đề GitHub của mình dưới dạng một vectơ duy nhất, vì vậy chúng ta cần “gộp” hoặc tính trung bình các lần biễu diễn token theo một cách nào đó. Một cách tiếp cận phổ biến là thực hiện CLS pooling trên đầu ra của mô hình, nơi ta chỉ cần thu thập trạng thái ẩn cuối cùng cho token đặc biệt [CLS]
. Hàm sau thực hiện thủ thuật này cho chúng ta:
def cls_pooling(model_output):
return model_output.last_hidden_state[:, 0]
Tiếp theo, chúng tôi sẽ tạo một chức năng trợ giúp sẽ tokanize danh sách các tài liệu, đặt các tensor trên GPU, đưa chúng vào mô hình và cuối cùng áp dụng CLS gộp cho các đầu ra:
def get_embeddings(text_list):
encoded_input = tokenizer(
text_list, padding=True, truncation=True, return_tensors="pt"
)
encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
model_output = model(**encoded_input)
return cls_pooling(model_output)
Chúng ta có thể kiểm tra chức năng hoạt động bằng cách cung cấp cho nó đoạn văn bản đầu tiên trong kho tài liệu và kiểm tra hình dạng đầu ra:
embedding = get_embeddings(comments_dataset["text"][0])
embedding.shape
torch.Size([1, 768])
Tuyệt vời, chúng ta đã chuyển đổi mục nhập đầu tiên trong kho tài liệu của mình thành một vectơ 768 chiều! Chúng ta có thể sử dụng Dataset.map()
để áp dụng hàm get_embeddings()
cho mỗi hàng trong kho tài liệu của mình, vì vậy hãy tạo một cột embeddings
mới như sau:
embeddings_dataset = comments_dataset.map(
lambda x: {"embeddings": get_embeddings(x["text"]).detach().cpu().numpy()[0]}
)
Lưu ý rằng chúng ta đã chuyển đổi các biểu diễn sang thành mảng NumPy - đó là vì 🤗 Datasets yêu cầu định dạng này khi ta cố gắng lập chỉ mục chúng bằng FAISS, điều mà ta sẽ thực hiện tiếp theo.
Sử dụng FAISS để tìm kiếm điểm tương đồng hiệu quả
Bây giờ chúng ta đã có một tập dữ liệu về các biểu diễn, chúng ta cần một số cách để tìm kiếm chúng. Để làm điều này, chúng ta sẽ sử dụng một cấu trúc dữ liệu đặc biệt trong 🤗 Datasets được gọi là FAISS index. FAISS (viết tắt của Facebook AI Similarity Search) là một thư viện cung cấp các thuật toán hiệu quả để nhanh chóng tìm kiếm và phân cụm các vectơ nhúng biểu diễn.
Ý tưởng cơ bản đằng sau FAISS là tạo ra một cấu trúc dữ liệu đặc biệt được gọi là index hay chỉ mục cho phép người ta tìm thấy các biểu diễn nhúng nào tương tự như biểu diễn nhúng đầu vào. Tạo chỉ mục FAISS trong 🤗 Datasets rất đơn giản - ta sử dụng hàm Dataset.add_faiss_index()
và chỉ định cột nào trong tập dữ liệu mà ta muốn lập chỉ mục:
embeddings_dataset.add_faiss_index(column="embeddings")
Bây giờ chúng ta có thể thực hiện các truy vấn trên chỉ mục này bằng cách thực hiện tra cứu những mẫu lân cận nhất thông qua hàm Dataset.get_nearest_examples()
. Hãy kiểm tra điều này bằng cách biểu diễn một câu hỏi như sau:
question = "How can I load a dataset offline?"
question_embedding = get_embeddings([question]).cpu().detach().numpy()
question_embedding.shape
torch.Size([1, 768])
Cũng giống như các tài liệu, giờ đây chúng ta có một vectơ 768 chiều đại diện cho truy vấn, mà chúng ta có thể so sánh với toàn bộ kho dữ liệu để tìm ra các cách biểu diễn tương tự nhất:
scores, samples = embeddings_dataset.get_nearest_examples(
"embeddings", question_embedding, k=5
)
Hàm Dataset.get_nearest_examples()
trả về một loạt điểm xếp hạng sự tương đồng giữa truy vấn và tài liệu và một tập hợp các mẫu tương ứng (ở đây, là 5 kết quả phù hợp nhất). Hãy thu thập những thứ này vào một pandas.DataFrame
để chúng ta có thể dễ dàng sắp xếp chúng:
import pandas as pd
samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True)
Bây giờ chúng ta có thể lặp lại một vài hàng đầu tiên để xem truy vấn của chúng ta khớp với các nhận xét có sẵn như thế nào:
for _, row in samples_df.iterrows():
print(f"COMMENT: {row.comments}")
print(f"SCORE: {row.scores}")
print(f"TITLE: {row.title}")
print(f"URL: {row.html_url}")
print("=" * 50)
print()
"""
COMMENT: Requiring online connection is a deal breaker in some cases unfortunately so it'd be great if offline mode is added similar to how `transformers` loads models offline fine.
@mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
SCORE: 25.505046844482422
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: The local dataset builders (csv, text , json and pandas) are now part of the `datasets` package since #1726 :)
You can now use them offline
\`\`\`python
datasets = load_dataset("text", data_files=data_files)
\`\`\`
We'll do a new release soon
SCORE: 24.555509567260742
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: I opened a PR that allows to reload modules that have already been loaded once even if there's no internet.
Let me know if you know other ways that can make the offline mode experience better. I'd be happy to add them :)
I already note the "freeze" modules option, to prevent local modules updates. It would be a cool feature.
----------
> @mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
Indeed `load_dataset` allows to load remote dataset script (squad, glue, etc.) but also you own local ones.
For example if you have a dataset script at `./my_dataset/my_dataset.py` then you can do
\`\`\`python
load_dataset("./my_dataset")
\`\`\`
and the dataset script will generate your dataset once and for all.
----------
About I'm looking into having `csv`, `json`, `text`, `pandas` dataset builders already included in the `datasets` package, so that they are available offline by default, as opposed to the other datasets that require the script to be downloaded.
cf #1724
SCORE: 24.14896583557129
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: > here is my way to load a dataset offline, but it **requires** an online machine
>
> 1. (online machine)
>
> ```
>
> import datasets
>
> data = datasets.load_dataset(...)
>
> data.save_to_disk(/YOUR/DATASET/DIR)
>
> ```
>
> 2. copy the dir from online to the offline machine
>
> 3. (offline machine)
>
> ```
>
> import datasets
>
> data = datasets.load_from_disk(/SAVED/DATA/DIR)
>
> ```
>
>
>
> HTH.
SCORE: 22.893993377685547
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: here is my way to load a dataset offline, but it **requires** an online machine
1. (online machine)
\`\`\`
import datasets
data = datasets.load_dataset(...)
data.save_to_disk(/YOUR/DATASET/DIR)
\`\`\`
2. copy the dir from online to the offline machine
3. (offline machine)
\`\`\`
import datasets
data = datasets.load_from_disk(/SAVED/DATA/DIR)
\`\`\`
HTH.
SCORE: 22.406635284423828
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
"""
Không tệ! Lần truy cập thứ hai của chúng ta dường như phù hợp với truy vấn.
✏️ Thử nghiệm thôi! Tạo truy vấn của riêng bạn và xem liệu bạn có thể tìm thấy câu trả lời trong các tài liệu đã truy xuất hay không. Bạn có thể phải tăng tham số k
trong Dataset.get_nearest_examples()
để mở rộng tìm kiếm.