Code implementation of Qwen2-based embeddings

This code is intended for Qwen2-based embedding models.

Bidirectional attention is enabled by default.
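To make "bidirectional" concrete: Qwen2, as a decoder-only language model, normally applies a causal mask, so token i can attend only to positions up to i. With bidirectional attention, every token can attend to every non-padding token in the input. The sketch below builds both masks as plain Python lists for one toy sequence. It illustrates the concept only and is not taken from modeling.py, whose internal mask construction may differ; the helper names are hypothetical.

```python
from typing import List


def causal_mask(padding_mask: List[int]) -> List[List[int]]:
    """Return M where M[i][j] == 1 if query i may attend to key j under causal attention.

    padding_mask[j] is 1 for a real token and 0 for padding.
    """
    n = len(padding_mask)
    return [[1 if j <= i and padding_mask[j] else 0 for j in range(n)] for i in range(n)]


def bidirectional_mask(padding_mask: List[int]) -> List[List[int]]:
    """Return M where M[i][j] == 1 if query i may attend to key j under bidirectional attention.

    Every query sees every real token; rows for padding queries are ignored downstream.
    """
    n = len(padding_mask)
    return [[1 if padding_mask[j] else 0 for j in range(n)] for i in range(n)]
```

For a sequence of three real tokens followed by one padding token (`[1, 1, 1, 0]`), the causal mask is lower-triangular over the real tokens, while every row of the bidirectional mask is `[1, 1, 1, 0]`.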

Usage

  1. Download configuration.py and modeling.py into your local gte-Qwen2 model directory.
  2. In config.json, change the modeling_qwen. prefix to modeling. in the auto_map field.
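Step 2 can also be scripted. The following is a minimal sketch using only the Python standard library; `patch_auto_map` is a hypothetical helper, and it assumes auto_map values are "module.ClassName" strings (tokenizer entries may be lists of such strings or null), which is how transformers remote-code configs usually store them.

```python
import json
from pathlib import Path

OLD_PREFIX = "modeling_qwen."
NEW_PREFIX = "modeling."


def _swap_prefix(ref):
    """Replace the modeling_qwen. prefix in a single auto_map reference."""
    if isinstance(ref, str) and ref.startswith(OLD_PREFIX):
        return NEW_PREFIX + ref[len(OLD_PREFIX):]
    if isinstance(ref, list):
        # Some entries (e.g. tokenizers) are lists of references or None.
        return [_swap_prefix(r) for r in ref]
    return ref  # None or unrelated references are left untouched


def patch_auto_map(model_dir: str) -> dict:
    """Rewrite modeling_qwen.* references in config.json's auto_map to modeling.*."""
    config_path = Path(model_dir) / "config.json"
    config = json.loads(config_path.read_text(encoding="utf-8"))
    if "auto_map" not in config:
        return {}  # nothing to patch; leave the file as it is
    config["auto_map"] = {key: _swap_prefix(value) for key, value in config["auto_map"].items()}
    config_path.write_text(json.dumps(config, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
    return config["auto_map"]
```

Usage, with a placeholder path for your own download: `patch_auto_map("./gte-Qwen2-1.5B-instruct")`. Other config keys are preserved; only matching auto_map references change.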

Recommendation: Enable Unpadding and Acceleration with xformers

This code can accelerate attention computation with xformers, which automatically dispatches to the best attention implementation available on the current device (for example, flash_attn where it is supported). Because xformers falls back to other kernels where flash_attn is unavailable, significant speedups are possible even on older GPUs such as the V100.
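The heading above pairs xformers with unpadding. Conceptually, unpadding drops the padding tokens from a padded batch and concatenates the remaining tokens into one packed sequence, recording cumulative sequence lengths (offsets) so that a variable-length attention kernel can keep the sequences separate. The sketch below is a plain-Python illustration of that idea, not code from modeling.py; `unpad` and `repad` are hypothetical names.

```python
from itertools import accumulate
from typing import List, Tuple


def unpad(
    input_ids: List[List[int]], attention_mask: List[List[int]]
) -> Tuple[List[int], List[int], List[int]]:
    """Pack a padded batch (left- or right-padded) into one flat token list.

    Returns (packed, cu_seqlens, kept):
      packed     -- all non-padding tokens, sequence after sequence
      cu_seqlens -- cu_seqlens[k] is where sequence k starts in packed; the last entry is the total
      kept       -- flat index (row * seq_len + col) of each packed token in the padded batch
    """
    seq_len = len(input_ids[0])
    packed, kept, lengths = [], [], []
    for row, (ids, mask) in enumerate(zip(input_ids, attention_mask)):
        count = 0
        for col, (token, m) in enumerate(zip(ids, mask)):
            if m:
                packed.append(token)
                kept.append(row * seq_len + col)
                count += 1
        lengths.append(count)
    cu_seqlens = [0] + list(accumulate(lengths))
    return packed, cu_seqlens, kept


def repad(packed: List[int], kept: List[int], batch: int, seq_len: int, pad_value: int = 0) -> List[List[int]]:
    """Scatter packed values back into a padded [batch, seq_len] layout."""
    flat = [pad_value] * (batch * seq_len)
    for value, pos in zip(packed, kept):
        flat[pos] = value
    return [flat[r * seq_len:(r + 1) * seq_len] for r in range(batch)]
```

For a batch `[[11, 12, 13, 0], [21, 22, 0, 0]]` with mask `[[1, 1, 1, 0], [1, 1, 0, 0]]`, attention runs over 5 real tokens instead of 8 padded slots, and `repad` restores the original layout afterwards.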

First, install xformers (PyTorch must already be installed):

If PyTorch was installed with conda:
    conda install xformers -c xformers
If PyTorch was installed with pip:
    # CUDA 11.8
    pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118
    # CUDA 12.1
    pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121

For more information, refer to Installing xformers.
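If you script environment setup, a small helper can choose the pip command from the CUDA version PyTorch was built with (`torch.version.cuda`, which is None for CPU-only builds). This is a hypothetical convenience sketch that knows only the two index URLs listed in the pip instructions; any other CUDA version raises an error rather than guessing a URL.

```python
import importlib.util
from typing import Optional

# The two wheel indexes listed in the pip instructions, keyed by CUDA version.
XFORMERS_INDEX_URLS = {
    "11.8": "https://download.pytorch.org/whl/cu118",
    "12.1": "https://download.pytorch.org/whl/cu121",
}


def xformers_pip_command(cuda_version: Optional[str]) -> str:
    """Return the pip command for a CUDA version string such as '12.1'."""
    url = XFORMERS_INDEX_URLS.get(cuda_version)
    if url is None:
        raise ValueError(f"no xformers index URL listed for CUDA {cuda_version!r}")
    return f"pip3 install -U xformers --index-url {url}"


def torch_cuda_version() -> Optional[str]:
    """CUDA version of the installed PyTorch build, or None if torch is missing or CPU-only."""
    if importlib.util.find_spec("torch") is None:
        return None
    import torch
    return torch.version.cuda


def xformers_installed() -> bool:
    """True if the xformers package is importable in this environment."""
    return importlib.util.find_spec("xformers") is not None
```

Typical use: `print(xformers_pip_command(torch_cuda_version()))`, then confirm the install with `xformers_installed()`.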

Then, when loading the model, set unpad_inputs and use_memory_efficient_attention to True and set torch_dtype to torch.float16 (or torch.bfloat16) to enable the acceleration.

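import torch
from transformers import AutoModel, AutoTokenizer

# Hub ID, or the local gte-Qwen2 directory prepared in the Usage steps
path = 'Alibaba-NLP/gte-Qwen2-1.5B-instruct'
device = torch.device('cuda')  # xformers memory-efficient attention runs on CUDA GPUs
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModel.from_pretrained(
    path,
    trust_remote_code=True,               # load the custom configuration.py / modeling.py
    unpad_inputs=True,                    # drop padding tokens before attention
    use_memory_efficient_attention=True,  # route attention through xformers
    torch_dtype=torch.float16             # or torch.bfloat16
).to(device)

inputs = tokenizer(['test input'], truncation=True, max_length=8192, padding=True, return_tensors='pt')

with torch.inference_mode():
    outputs = model(**inputs.to(device))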

Alternatively, set unpad_inputs and use_memory_efficient_attention to true directly in the model's config.json; you then no longer need to pass them in code.
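As a minimal sketch of that alternative, the script below sets both flags in a local copy of config.json using only the standard library. `enable_xformers_flags` is a hypothetical helper, the directory path is a placeholder for your own download, and the sketch assumes the custom configuration code reads these two keys from config.json, as the paragraph above states.

```python
import json
from pathlib import Path


def enable_xformers_flags(model_dir: str) -> dict:
    """Set unpad_inputs and use_memory_efficient_attention to true in config.json."""
    config_path = Path(model_dir) / "config.json"
    config = json.loads(config_path.read_text(encoding="utf-8"))
    # Python True is serialized as JSON true.
    config["unpad_inputs"] = True
    config["use_memory_efficient_attention"] = True
    config_path.write_text(json.dumps(config, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
    return config
```

After running `enable_xformers_flags("./gte-Qwen2-1.5B-instruct")`, the model can be loaded from that directory with only `trust_remote_code=True` and `torch_dtype=torch.float16` (or `torch.bfloat16`); the dtype is not part of this change and still needs to be passed.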

Citation

@misc{zhang2024mgte,
  title={mGTE: Generalized Long-Context Text Representation and Reranking Models for Multilingual Text Retrieval}, 
  author={Xin Zhang and Yanzhao Zhang and Dingkun Long and Wen Xie and Ziqi Dai and Jialong Tang and Huan Lin and Baosong Yang and Pengjun Xie and Fei Huang and Meishan Zhang and Wenjie Li and Min Zhang},
  year={2024},
  eprint={2407.19669},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2407.19669}, 
}