[Alibaba Open Source] Qwen3-Embedding and Qwen3-Reranker

Running the Qwen3-Reranker Model with Transformers

The code for main_reranker.py is as follows:

import torch  
from transformers import AutoTokenizer, AutoModelForCausalLM  
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel  
from typing import List, Optional  

# ---------------------
# Initialize the model and parameters
# ---------------------
MODEL_PATH = "/gemini/pretrain/Qwen3-Reranker-0.6B"  # Replace with a local model path or a Hugging Face Hub model ID

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side='left')  
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).eval()  

# If your GPU and environment allow it, use the following line instead to enable the GPU, fp16, and FlashAttention 2
# model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda().eval()  

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  
model = model.to(device)  

# Token ids for the "no"/"yes" judgment tokens that the model emits at the final position
token_false_id = tokenizer.convert_tokens_to_ids("no")
token_true_id = tokenizer.convert_tokens_to_ids("yes")
max_length = 8192

prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"  
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"  
prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)  
suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)  

def format_instruction(instruction, query, doc):  
    if instruction is None:  
        instruction = 'Given a web search query, retrieve relevant passages that answer the query'  
    output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(instruction=instruction, query=query, doc=doc)  
    return output  

def process_inputs(pairs):
    # pairs: List[str] of formatted "<Instruct>/<Query>/<Document>" strings
    inputs = tokenizer(
        pairs, padding=False, truncation='longest_first',
        return_attention_mask=False, max_length=max_length - len(prefix_tokens) - len(suffix_tokens)
    )
    # Wrap each truncated body with the chat-template prefix and suffix tokens
    for i, ele in enumerate(inputs['input_ids']):
        inputs['input_ids'][i] = prefix_tokens + ele + suffix_tokens
    # Left-pad the batch to its longest sequence; max_length is omitted here
    # because tokenizer.pad ignores it when padding=True (it only triggered a warning)
    inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt")
    for key in inputs:
        inputs[key] = inputs[key].to(model.device)
    return inputs

@torch.no_grad()
def compute_logits(inputs):
    # Take the logits at the final position and keep only the "no"/"yes" entries;
    # softmax over that pair yields P("yes"), which is used as the relevance score
    batch_scores = model(**inputs).logits[:, -1, :]
    true_vector = batch_scores[:, token_true_id]  
    false_vector = batch_scores[:, token_false_id]  
    batch_scores = torch.stack([false_vector, true_vector], dim=1)  
    batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)  
    scores = batch_scores[:, 1].exp().tolist()  
    return scores  

# -----------------  
# FASTAPI 定义部分  
# -----------------  
app = FastAPI()  

class RerankRequest(BaseModel):  
    queries: List[str]  
    documents: List[str]  
    instruction: Optional[str] = None  # Optional; falls back to the default instruction

class PairRerankRequest(BaseModel):  
    pairs: List[List[str]]  # Pre-paired [query, doc] lists can be passed directly
    instruction: Optional[str] = None  

@app.post("/rerank")  
def rerank(req: RerankRequest):  
    queries = req.queries  
    documents = req.documents  
    instruction = req.instruction  
    if len(queries) != len(documents):
        raise HTTPException(status_code=400, detail="queries and documents must have the same length (one-to-one pairs).")
    pairs = [format_instruction(instruction, q, d) for q, d in zip(queries, documents)]  
    inputs = process_inputs(pairs)  
    scores = compute_logits(inputs)  
    return {"scores": scores}  

@app.post("/rerank_pair")  
def rerank_pair(req: PairRerankRequest):  
    instruction = req.instruction  
    pairs = [format_instruction(instruction, q, d) for q, d in req.pairs]  
    inputs = process_inputs(pairs)  
    scores = compute_logits(inputs)  
    return {"scores": scores}

Running the Qwen3-Reranker Model and Calling the API

Starting the API


conda init bash
source /root/.bashrc
conda activate sglang
uvicorn main_reranker:app --reload --host 0.0.0.0 --port 8081
INFO:     Will watch for changes in these directories: ['/gemini/code']
INFO:     Uvicorn running on http://0.0.0.0:8081 (Press CTRL+C to quit)
INFO:     Started reloader process [1675] using StatReload
INFO:     Started server process [1677]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
INFO:     127.0.0.1:39716 - "POST /rerank HTTP/1.1" 200 OK
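
The --reload flag keeps uvicorn in the foreground with auto-reload, which is convenient during development. For a true background run, one option is nohup (the log file name here is arbitrary, a sketch rather than part of the original setup):

nohup uvicorn main_reranker:app --host 0.0.0.0 --port 8081 > reranker.log 2>&1 &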

Calling the API


curl -s -X POST "http://localhost:8081/rerank" \
  -H "Content-Type: application/json" \
  -d '{
  "queries":["What is the capital of China?","Explain gravity"],
  "documents":["The capital of China is Beijing.","Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."]}'| jq

Response:

{
  "scores": [
    0.9999999403953552,
    0.0000000596046448015625
  ]
}
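
The /rerank_pair endpoint defined above accepts pre-paired [query, document] lists. It is not exercised in the original logs, so treat the following call as an untested sketch matching the PairRerankRequest schema:

curl -s -X POST "http://localhost:8081/rerank_pair" \
  -H "Content-Type: application/json" \
  -d '{
  "pairs":[["What is the capital of China?","The capital of China is Beijing."],
           ["Explain gravity","Gravity is a force that attracts two bodies towards each other."]]}' | jq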

Running the Qwen3-Embedding Model with Transformers

The code for main_embedding.py is as follows:

import torch  
import torch.nn.functional as F  
from torch import Tensor  
from transformers import AutoTokenizer, AutoModel  
from fastapi import FastAPI, HTTPException  
from pydantic import BaseModel  
from typing import List  

# Load the local model
MODEL_DIR = "/gemini/pretrain/Qwen3-Embedding-8B"  # Make sure this path contains a complete model in Hugging Face format

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, padding_side='left')  
model = AutoModel.from_pretrained(MODEL_DIR)  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  
model = model.to(device)  
model.eval()  

# Pooling: take each sequence's last non-padding token embedding
def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # With left padding the final position is always a real token; otherwise,
    # use the attention mask to locate each sequence's last non-padding position
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:  
        return last_hidden_states[:, -1]  
    else:  
        sequence_lengths = attention_mask.sum(dim=1) - 1  
        batch_size = last_hidden_states.shape[0]  
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]  

# Request schema
class QueryRequest(BaseModel):
    task: str  # Task description, e.g. 'Given a web search query, retrieve relevant passages that answer the query'
    queries: List[str]  # List of query strings
    documents: List[str]  # List of candidate documents

# API app  
app = FastAPI()  

# API route
@app.post("/embedding/retrieve")
def embedding_retrieve(request: QueryRequest):
    if not request.queries or not request.documents:
        raise HTTPException(status_code=400, detail="queries and documents must both be non-empty.")
    task = request.task
    queries = [f"Instruct: {task}\nQuery:{q}" for q in request.queries]
    documents = request.documents
    input_texts = queries + documents

    # Tokenize:  
    batch_dict = tokenizer(  
        input_texts,  
        padding=True,  
        truncation=True,  
        max_length=8192,  
        return_tensors="pt"  
    )  
    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}  

    with torch.no_grad():  
        outputs = model(**batch_dict)  
        embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])  
        embeddings = F.normalize(embeddings, p=2, dim=1)  # L2-normalize

        q_num = len(queries)  
        d_num = len(documents)  
        # Score == cosine similarity (embeddings are L2-normalized), shape: [q_num, d_num]
        scores = (embeddings[:q_num] @ embeddings[q_num:].T)  
        scores = scores.cpu().tolist()  

    # Extend the returned fields as needed
    return {  
        "scores": scores,  
        "query_count": q_num,  
        "doc_count": d_num,  
        "query_documents": {"queries": request.queries, "documents": request.documents}  
    }
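
Qwen3-Embedding-8B is a large model, and loading it in full fp32 precision is slow and memory-hungry. Mirroring the commented-out line in the reranker script, the load can be switched to fp16 and, if the GPU and installed packages support it, FlashAttention 2; whether this applies depends on your environment:

# Optional: fp16 + FlashAttention 2, if the GPU and environment allow it
# model = AutoModel.from_pretrained(MODEL_DIR, torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda().eval()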

Running main_embedding.py

uvicorn main_embedding:app --reload --host 0.0.0.0 --port 8080

INFO:     Will watch for changes in these directories: ['/gemini/code']
INFO:     Uvicorn running on http://0.0.0.0:8080 (Press CTRL+C to quit)
INFO:     Started reloader process [1827] using StatReload
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:19<00:00,  4.96s/it]
INFO:     Started server process [1829]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     127.0.0.1:43076 - "POST /embedding/retrieve HTTP/1.1" 200 OK

Calling the API

curl -s -X POST "http://localhost:8080/embedding/retrieve" \
  -H "Content-Type: application/json" \
  -d '{  
    "task": "Given a web search query, retrieve relevant passages that answer the query",  
    "queries": [  
      "What is the capital of China?",  
      "Explain gravity"  
    ],  
    "documents": [  
      "The capital of China is Beijing.",  
      "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."  
    ]  
  }' | jq

Response:

{
  "scores": [
    [
      0.7493016719818115,
      0.07506481558084488
    ],
    [
      0.0879596695303917,
      0.6318400502204895
    ]
  ],
  "query_count": 2,
  "doc_count": 2,
  "query_documents": {
    "queries": [
      "What is the capital of China?",
      "Explain gravity"
    ],
    "documents": [
      "The capital of China is Beijing.",
      "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
    ]
  }
}
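
In the score matrix above the diagonal dominates (0.749 and 0.632), i.e. each query scores highest against its own answer. A small client that calls the endpoint and picks the best document per query might look like this (a minimal sketch using the requests library, not part of the original setup):

import requests

payload = {
    "task": "Given a web search query, retrieve relevant passages that answer the query",
    "queries": ["What is the capital of China?", "Explain gravity"],
    "documents": [
        "The capital of China is Beijing.",
        "Gravity is a force that attracts two bodies towards each other.",
    ],
}
resp = requests.post("http://localhost:8080/embedding/retrieve", json=payload)
scores = resp.json()["scores"]  # shape: [query_count, doc_count]
for q, row in zip(payload["queries"], scores):
    best = max(range(len(row)), key=row.__getitem__)  # index of the highest-scoring document
    print(f"{q!r} -> doc {best} (score {row[best]:.3f})")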
