Model max_seq_length
Should we specifically set model.max_seq_length = 512?
Hi, the model was trained on data with a sequence length of 512. I recommend setting model.max_seq_length <= 1024.
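To make the recommendation above concrete: max_seq_length is a token budget, and any tokens past it are cut off before encoding. The sketch below illustrates that truncation in plain Python. `truncate_tokens` is a hypothetical helper and the integer list stands in for real token IDs. In sentence-transformers itself, the limit is set on the loaded model, e.g. `model.max_seq_length = 512`.

```python
# Sketch only: what a max_seq_length limit does to one tokenized input.
# Token sequences longer than the limit are cut off, so the tail of a long
# document never reaches the model. In sentence-transformers the limit is a
# property of the loaded model, e.g. `model.max_seq_length = 512`.
from typing import List, Tuple


def truncate_tokens(token_ids: List[int], max_seq_length: int) -> Tuple[List[int], int]:
    """Hypothetical helper: return (kept tokens, number of dropped tokens)."""
    if max_seq_length <= 0:
        raise ValueError("max_seq_length must be positive")
    kept = token_ids[:max_seq_length]
    return kept, len(token_ids) - len(kept)


if __name__ == "__main__":
    doc = list(range(1500))  # stand-in for 1,500 real token IDs
    for limit in (512, 1024):
        kept, dropped = truncate_tokens(doc, limit)
        print(f"max_seq_length={limit}: kept {len(kept)}, dropped {dropped}")
```

With a 1,500-token input, a limit of 512 keeps only about the first third of the text, while 1024 keeps roughly two thirds.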
@infgrad Thank you very much for your reply! I really appreciate it!
I ran a quick test and found that the similarity between queries and docs is pretty low.
For example, the cosine similarity between "What is apple?" and "Apple is a kind of fruit." is only 0.24.
from sentence_transformers import SentenceTransformer
query_prompt_name = "s2s_query" # or s2p_query
queries = [
"What is apple?",
"What is the treatment of dementia?",
]
# docs do not need any prompts
docs = [
"Apple is a kind of fruit.",
"Thank you, Representative Waxman, for taking the time to speak with me today about your health policy work. This is Bridget Keene with JAMA.",
]
# !The default dimension is 1024, if you need other dimensions, please clone the model and modify `modules.json`
# to replace `2_Dense_1024` with another dimension, e.g. `2_Dense_256` or `2_Dense_8192` !
model = SentenceTransformer("infgrad/stella_en_1.5B_v5")
query_embeddings = model.encode(queries, prompt_name=query_prompt_name)
doc_embeddings = model.encode(docs)
print(query_embeddings.shape, doc_embeddings.shape)
similarities = model.similarity(query_embeddings, doc_embeddings)
print(similarities)
The result is as follows:
(2, 1024) (2, 1024)
tensor([[0.2433, 0.2636],
        [0.2337, 0.2678]])
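For reading the matrix above: `model.similarity` returns one row per query and one column per document, so 0.2433 is "What is apple?" scored against "Apple is a kind of fruit." and 0.2636 is the same query scored against the JAMA sentence. A minimal sketch of that layout, assuming cosine similarity (the sentence-transformers default unless the model config selects another function) and using made-up 3-dimensional vectors in place of real 1024-dimensional embeddings:

```python
# Sketch only: layout of a (queries x docs) cosine-similarity matrix.
# Row i is query i, column j is document j. The toy vectors are made up.
import math
from typing import List

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def similarity_matrix(queries: List[Vector], docs: List[Vector]) -> List[List[float]]:
    # result[i][j] = cosine(query i, doc j)
    return [[cosine(q, d) for d in docs] for q in queries]


if __name__ == "__main__":
    query_embeddings = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
    doc_embeddings = [[1.0, 1.0, 0.0], [0.0, 1.0, 1.0]]
    for row in similarity_matrix(query_embeddings, doc_embeddings):
        print([round(x, 4) for x in row])
    # prints [0.7071, 0.0] then [0.7071, 0.7071]
```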
Could you please help me check whether there is anything wrong in the code?
Thank you very much in advance!
Best regards,
Shuyue
July 17th, 2024
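On `prompt_name`: in sentence-transformers, passing `prompt_name` to `encode` looks the name up in the model's stored prompts and prepends that string to every input, while calling `encode` without it (as for the docs above) leaves the text unchanged. The sketch below mimics that behaviour in plain Python. The prompt strings are placeholders, not the real stella prompts, which ship with the model, and `apply_prompt` is a hypothetical helper rather than the library's code.

```python
# Sketch only: mimics how `encode(..., prompt_name=...)` prepends a stored
# prompt to each input. The prompt strings are PLACEHOLDERS, not stella's
# real prompts, which ship with the model.
from typing import Dict, List, Optional

HYPOTHETICAL_PROMPTS: Dict[str, str] = {
    "s2s_query": "<s2s instruction> ",
    "s2p_query": "<s2p instruction> ",
}


def apply_prompt(
    texts: List[str], prompts: Dict[str, str], prompt_name: Optional[str] = None
) -> List[str]:
    """Hypothetical helper: prepend prompts[prompt_name] to every text."""
    if prompt_name is None:
        return list(texts)  # no prompt: texts are encoded unchanged (like the docs)
    if prompt_name not in prompts:
        raise KeyError(f"unknown prompt_name: {prompt_name!r}")
    return [prompts[prompt_name] + text for text in texts]


if __name__ == "__main__":
    print(apply_prompt(["What is apple?"], HYPOTHETICAL_PROMPTS, "s2s_query"))
    print(apply_prompt(["What is apple?"], HYPOTHETICAL_PROMPTS, "s2p_query"))
    print(apply_prompt(["Apple is a kind of fruit."], HYPOTHETICAL_PROMPTS))
```

Because the prefix changes the text the model actually sees, `s2s_query` and `s2p_query` yield different query embeddings, so scores obtained with one are not directly comparable with scores obtained with the other.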
Hi, this is really weird!
Here is my code:
from sentence_transformers import SentenceTransformer

if __name__ == "__main__":
    query_prompt_name = "s2s_query"  # or s2p_query
    queries = [
        "What is apple?",
        "What is the treatment of dementia?",
    ]
    # docs do not need any prompts
    docs = [
        "Apple is a kind of fruit.",
        "Thank you, Representative Waxman, for taking the time to speak with me today about your health policy work. This is Bridget Keene with JAMA.",
    ]
    # !The default dimension is 1024, if you need other dimensions, please clone the model and modify `modules.json`
    # to replace `2_Dense_1024` with another dimension, e.g. `2_Dense_256` or `2_Dense_8192` !
    model = SentenceTransformer("MODEL_PATH", trust_remote_code=True)
    query_embeddings = model.encode(queries, prompt_name=query_prompt_name)
    doc_embeddings = model.encode(docs)
    print(query_embeddings.shape, doc_embeddings.shape)
    similarities = model.similarity(query_embeddings, doc_embeddings)
    print(similarities)
The output is:
(2, 1024) (2, 1024)
tensor([[0.6056, 0.1705],
        [0.1453, 0.1893]])
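Comparing the two outputs by ranking rather than by raw scores: in the first post, "What is apple?" scores the JAMA sentence (0.2636) above "Apple is a kind of fruit." (0.2433), whereas in this reply the apple sentence wins clearly (0.6056 vs 0.1705). The check below finds the top-scoring document for each query in the two matrices exactly as posted; treating document i as the intended match for query i is an assumption.

```python
# Sketch only: top-scoring document per query for the two matrices posted
# in this thread. The expected pairing (query i -> doc i) is an assumption.
from typing import List


def top_doc_per_query(scores: List[List[float]]) -> List[int]:
    # Index of the highest score in each row (one row per query).
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


reported_by_user = [[0.2433, 0.2636], [0.2337, 0.2678]]
reported_in_reply = [[0.6056, 0.1705], [0.1453, 0.1893]]
expected = [0, 1]

for label, scores in (("user", reported_by_user), ("reply", reported_in_reply)):
    top = top_doc_per_query(scores)
    print(label, top, "matches expected" if top == expected else "does not match expected")
# prints: user [1, 1] does not match expected
#         reply [0, 1] matches expected
```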
Please update your model and try again.
If the program prints any warnings, please let me know.
Best regards,
@shuyuej
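On the comment in both snippets about `2_Dense_1024`: changing the output dimension means pointing the Dense entry of `modules.json` in a local clone at a different projection folder. A minimal sketch, assuming `modules.json` follows the usual sentence-transformers layout (a list of entries with `idx`, `name`, `path` and `type`) and that the clone already contains the target folder. `switch_dense_folder` is a hypothetical helper, and the demo runs on a throwaway directory rather than a real model.

```python
# Sketch only: repoint the Dense module in a local clone's modules.json.
# Assumes the standard sentence-transformers modules.json layout and that the
# target folder (e.g. 2_Dense_256) already exists in the clone.
import json
import tempfile
from pathlib import Path


def switch_dense_folder(model_dir: str, old: str = "2_Dense_1024", new: str = "2_Dense_256") -> bool:
    """Hypothetical helper: replace `old` with `new` in modules.json.

    Returns True if an entry was changed, False if `old` was not found.
    """
    root = Path(model_dir)
    if not (root / new).is_dir():
        raise FileNotFoundError(f"{new} not found in {model_dir}")
    modules_path = root / "modules.json"
    modules = json.loads(modules_path.read_text(encoding="utf-8"))
    changed = False
    for module in modules:
        if module.get("path") == old:
            module["path"] = new
            changed = True
    if changed:
        modules_path.write_text(json.dumps(modules, indent=2), encoding="utf-8")
    return changed


if __name__ == "__main__":
    # Demo on a throwaway directory, not a real model clone.
    with tempfile.TemporaryDirectory() as tmp:
        (Path(tmp) / "2_Dense_256").mkdir()
        sample = [
            {"idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer"},
            {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"},
            {"idx": 2, "name": "2", "path": "2_Dense_1024", "type": "sentence_transformers.models.Dense"},
        ]
        (Path(tmp) / "modules.json").write_text(json.dumps(sample), encoding="utf-8")
        print(switch_dense_folder(tmp))  # True
        print(json.loads((Path(tmp) / "modules.json").read_text(encoding="utf-8"))[2]["path"])  # 2_Dense_256
```

The helper refuses to edit anything if the target folder is missing, so a typo in the dimension cannot leave `modules.json` pointing at a folder that does not exist.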
Here is my environment:
accelerate 0.31.0
aiofiles 23.2.1
aiohttp 3.9.5
aiosignal 1.3.1
altair 5.2.0
annotated-types 0.6.0
anyio 4.3.0
asttokens 2.0.5
async-timeout 4.0.3
attrs 23.2.0
beautifulsoup4 4.12.3
beir 2.0.0
bitsandbytes 0.43.0
cachetools 5.3.3
certifi 2024.2.2
charset-normalizer 3.3.2
click 8.1.7
cloudpickle 3.0.0
cmake 3.29.0.1
colorama 0.4.6
comm 0.2.2
contourpy 1.2.0
cupy-cuda12x 12.1.0
cycler 0.12.1
datasets 2.20.0
debugpy 1.6.7
decorator 5.1.1
deepspeed 0.14.4
dill 0.3.8
diskcache 5.6.3
distro 1.9.0
dnspython 2.6.1
docker-pycreds 0.4.0
einops 0.8.0
elasticsearch 7.9.1
et-xmlfile 1.1.0
eval_type_backport 0.2.0
exceptiongroup 1.2.0
executing 0.8.3
faiss-cpu 1.8.0
fastapi 0.110.0
fastrlock 0.8.2
ffmpy 0.3.2
filelock 3.13.1
flash-attn 2.5.9.post1
fonttools 4.49.0
frozenlist 1.4.1
fsspec 2024.2.0
gitdb 4.0.11
GitPython 3.1.43
google 3.0.0
GPUtil 1.4.0
gradio 4.36.1
gradio_client 1.0.1
h11 0.14.0
hjson 3.1.0
httpcore 1.0.4
httptools 0.6.1
httpx 0.27.0
huggingface-hub 0.23.4
idna 3.6
importlib_metadata 7.1.0
importlib_resources 6.3.0
interegular 0.3.3
ipykernel 6.29.3
ipython 8.20.0
jedi 0.18.1
jieba 0.42.1
Jinja2 3.1.3
jiojio 1.2.5
jionlp 1.5.7
joblib 1.3.2
jsonlines 4.0.0
jsonschema 4.21.1
jsonschema-specifications 2023.12.1
jupyter_client 8.6.1
jupyter_core 5.7.2
kiwisolver 1.4.5
lark 1.1.9
llvmlite 0.42.0
lm-format-enforcer 0.10.1
loguru 0.7.2
markdown-it-py 3.0.0
markdown_to_json 2.1.1
MarkupSafe 2.1.5
matplotlib 3.8.3
matplotlib-inline 0.1.6
mdurl 0.1.2
mpmath 1.3.0
msgpack 1.0.8
mteb 1.12.79
multidict 6.0.5
multiprocess 0.70.16
nest_asyncio 1.6.0
networkx 3.2.1
ninja 1.11.1.1
numba 0.59.0
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-ml-py 12.550.52
nvidia-nccl-cu12 2.20.5
nvidia-nvjitlink-cu12 12.4.99
nvidia-nvtx-cu12 12.1.105
openai 1.14.0
openpyxl 3.1.2
orjson 3.9.15
outlines 0.0.46
packaging 24.0
pandas 2.2.1
parso 0.8.3
peft 0.9.0
pexpect 4.8.0
pillow 10.2.0
pip 23.3.1
platformdirs 4.2.0
polars 0.20.31
prometheus_client 0.20.0
prometheus-fastapi-instrumentator 7.0.0
prompt-toolkit 3.0.43
protobuf 5.26.0
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
py 1.11.0
py-cpuinfo 9.0.0
pyairports 2.1.1
pyarrow 16.1.0
pyarrow-hotfix 0.6
pycountry 24.6.1
pycryptodome 3.9.9
pydantic 2.6.4
pydantic_core 2.16.3
pydub 0.25.1
Pygments 2.15.1
PyJWT 2.8.0
pymongo 4.8.0
pynvml 11.5.0
pyparsing 3.1.2
python-dateutil 2.9.0
python-dotenv 1.0.1
python-multipart 0.0.9
pytrec-eval 0.5
pytrec-eval-terrier 0.5.6
pytz 2020.5
PyYAML 6.0.1
pyzmq 25.1.2
ray 2.9.3
referencing 0.33.0
regex 2023.12.25
requests 2.32.3
retry 0.9.2
rich 13.7.1
rjieba 0.1.11
roformer 0.4.3
rpds-py 0.18.0
ruff 0.3.2
safetensors 0.4.2
scikit-learn 1.4.1.post1
scipy 1.12.0
semantic-version 2.10.0
sentence-transformers 3.0.1
sentencepiece 0.2.0
sentry-sdk 2.7.1
setproctitle 1.3.3
setuptools 68.2.2
shellingham 1.5.4
six 1.16.0
smmap 5.0.1
sniffio 1.3.1
soupsieve 2.5
stack-data 0.2.0
starlette 0.36.3
sympy 1.12
threadpoolctl 3.3.0
tiktoken 0.6.0
tokenizers 0.19.1
tomlkit 0.12.0
toolz 0.12.1
torch 2.3.0
torchvision 0.18.0
tornado 6.4
tqdm 4.66.4
traitlets 5.7.1
transformers 4.42.3
triton 2.3.0
typer 0.12.3
typing_extensions 4.10.0
tzdata 2024.1
urllib3 2.2.1
uvicorn 0.28.0
uvloop 0.19.0
vllm 0.5.1
vllm-flash-attn 2.5.9
vllm-nccl-cu12 2.18.1.0.1.0
volcengine 1.0.133
volcengine-python-sdk 1.0.86
wandb 0.17.3
watchfiles 0.21.0
wcwidth 0.2.5
websockets 11.0.3
wheel 0.41.2
xformers 0.0.26.post1
xxhash 3.4.1
yarl 1.9.4
zhconv 1.4.3
zhipuai 2.0.1
zipfile36 0.1.3
zipp 3.17.0
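When comparing results across machines, a short report of the packages most likely to affect embeddings is easier to diff than a full `pip list`. A minimal sketch using the standard library's `importlib.metadata`; the package selection is a judgment call, not an official checklist.

```python
# Sketch only: report versions of a few packages that commonly matter when
# two machines produce different embeddings. The list is a judgment call.
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

RELEVANT = (
    "sentence-transformers",
    "transformers",
    "torch",
    "tokenizers",
    "huggingface-hub",
    "flash-attn",
)


def package_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None  # not installed
    return result


if __name__ == "__main__":
    for name, ver in package_versions(RELEVANT).items():
        print(f"{name:<25} {ver or 'not installed'}")
```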
Did you figure out what the issue was? I'm pretty curious.
- Tom Aarsen