[Alibaba Open Source] Qwen3-Embedding and Qwen3-Reranker
Running the Qwen3-Reranker Model with Transformers
The code for main_reranker.py is as follows:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional

# ---------------------
# Initialize the model and parameters
# ---------------------
MODEL_PATH = "/gemini/pretrain/Qwen3-Reranker-0.6B"  # replace with a local model path or a Hugging Face Hub name

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side='left')
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).eval()
# If your GPU and environment allow it, use the following line instead to enable fp16 and flash attention:
# model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda().eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

token_false_id = tokenizer.convert_tokens_to_ids("no")
token_true_id = tokenizer.convert_tokens_to_ids("yes")
max_length = 8192

# Chat-template wrapper: the model is prompted to answer "yes" or "no" for each pair.
prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)

def format_instruction(instruction, query, doc):
    # Fall back to the default instruction when none is provided.
    if instruction is None:
        instruction = 'Given a web search query, retrieve relevant passages that answer the query'
    output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(instruction=instruction, query=query, doc=doc)
    return output

def process_inputs(pairs):
    # pairs: List[str]
    # Tokenize without padding first, reserving room for the prefix and suffix tokens.
    inputs = tokenizer(
        pairs, padding=False, truncation='longest_first',
        return_attention_mask=False, max_length=max_length - len(prefix_tokens) - len(suffix_tokens)
    )
    for i, ele in enumerate(inputs['input_ids']):
        inputs['input_ids'][i] = prefix_tokens + ele + suffix_tokens
    inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
    for key in inputs:
        inputs[key] = inputs[key].to(model.device)
    return inputs

@torch.no_grad()
def compute_logits(inputs):
    # Read the logits at the final position and turn the "yes"/"no" pair into a probability.
    batch_scores = model(**inputs).logits[:, -1, :]
    true_vector = batch_scores[:, token_true_id]
    false_vector = batch_scores[:, token_false_id]
    batch_scores = torch.stack([false_vector, true_vector], dim=1)
    batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
    scores = batch_scores[:, 1].exp().tolist()
    return scores

# -----------------
# FastAPI definitions
# -----------------
app = FastAPI()

class RerankRequest(BaseModel):
    queries: List[str]
    documents: List[str]
    instruction: Optional[str] = None  # optional; the default instruction is used when omitted

class PairRerankRequest(BaseModel):
    pairs: List[List[str]]  # accepts a list of already-paired [query, doc] entries
    instruction: Optional[str] = None

@app.post("/rerank")
def rerank(req: RerankRequest):
    queries = req.queries
    documents = req.documents
    instruction = req.instruction
    if len(queries) != len(documents):
        return {"error": "queries length and documents length must match one to one."}
    pairs = [format_instruction(instruction, q, d) for q, d in zip(queries, documents)]
    inputs = process_inputs(pairs)
    scores = compute_logits(inputs)
    return {"scores": scores}

@app.post("/rerank_pair")
def rerank_pair(req: PairRerankRequest):
    instruction = req.instruction
    pairs = [format_instruction(instruction, q, d) for q, d in req.pairs]
    inputs = process_inputs(pairs)
    scores = compute_logits(inputs)
    return {"scores": scores}
Running the Qwen3-Reranker Model and Calling the API
Run the API in the background
conda init bash
source /root/.bashrc
conda activate sglang
uvicorn main_reranker:app --reload --host 0.0.0.0 --port 8081
INFO: Will watch for changes in these directories: ['/gemini/code']
INFO: Uvicorn running on http://0.0.0.0:8081 (Press CTRL+C to quit)
INFO: Started reloader process [1675] using StatReload
INFO: Started server process [1677]
INFO: Waiting for application startup.
INFO: Application startup complete.
You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
/root/miniconda3/envs/sglang/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:2718: UserWarning: `max_length` is ignored when `padding`=`True` and there is no truncation strategy. To pad to max length, use `padding='max_length'`.
warnings.warn(
INFO: 127.0.0.1:39716 - "POST /rerank HTTP/1.1" 200 OK
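Note: the --reload flag auto-restarts the server on code changes and is meant for development; drop it in production. The two warnings above are expected with this script: it tokenizes first and pads in a separate tokenizer.pad call (hence the fast-tokenizer hint), and max_length is ignored by that pad call because padding=True pads to the longest sequence in the batch rather than to max_length.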
Calling the API
curl -s -X POST "http://localhost:8081/rerank" \
  -H "Content-Type: application/json" \
  -d '{
        "queries": ["What is the capital of China?", "Explain gravity"],
        "documents": ["The capital of China is Beijing.", "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."]
      }' | jq
Response
{
"scores": [
0.9999999403953552,
0.0000000596046448015625
]
}
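The /rerank_pair endpoint, which takes pre-paired [query, doc] lists, can be exercised the same way. A minimal Python client sketch using the requests library (the host and port follow the uvicorn command above; adjust as needed):

import requests

# Client for the /rerank_pair endpoint defined in main_reranker.py.
payload = {
    "pairs": [
        ["What is the capital of China?", "The capital of China is Beijing."],
        ["Explain gravity", "Gravity is a force that attracts two bodies towards each other."],
    ]
    # "instruction" is optional; the default instruction is used when omitted.
}
resp = requests.post("http://localhost:8081/rerank_pair", json=payload)
print(resp.json())  # {"scores": [...]} -- one relevance score per pair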
Running the Qwen3-Embedding Model with Transformers
The code for main_embedding.py is as follows:
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List

# Load the local model
MODEL_DIR = "/gemini/pretrain/Qwen3-Embedding-8B"  # make sure this path holds a complete Hugging Face format model
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, padding_side='left')
model = AutoModel.from_pretrained(MODEL_DIR)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

# Pooling strategy: take the hidden state at each sequence's last real token
def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

# Request schema
class QueryRequest(BaseModel):
    task: str  # task description, e.g. 'Given a web search query, retrieve relevant passages that answer the query'
    queries: List[str]  # list of queries
    documents: List[str]  # list of candidate documents

# API app
app = FastAPI()

# API route
@app.post("/embedding/retrieve")
def embedding_retrieve(request: QueryRequest):
    task = request.task
    # Only queries carry the instruction; documents are embedded as-is.
    queries = [f"Instruct: {task}\nQuery:{q}" for q in request.queries]
    documents = request.documents
    input_texts = queries + documents
    # Tokenize
    batch_dict = tokenizer(
        input_texts,
        padding=True,
        truncation=True,
        max_length=8192,
        return_tensors="pt"
    )
    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)  # L2-normalize
    q_num = len(queries)
    d_num = len(documents)
    # score == similarity, shape: [q_num, d_num]
    scores = (embeddings[:q_num] @ embeddings[q_num:].T)
    scores = scores.cpu().tolist()
    # Adjust the returned payload to your needs
    return {
        "scores": scores,
        "query_count": q_num,
        "doc_count": d_num,
        "query_documents": {"queries": request.queries, "documents": request.documents}
    }
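A quick way to see what last_token_pool does: with left padding (the tokenizer above is created with padding_side='left'), the last position of every row is a real token, so pooling reduces to hidden[:, -1]; with right padding, it indexes each row at its own final real token. A toy sketch, assuming last_token_pool from main_embedding.py is in scope:

import torch

hidden = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)  # [batch=2, seq=3, dim=2]
mask_left = torch.tensor([[0, 1, 1], [1, 1, 1]])   # left-padded batch
mask_right = torch.tensor([[1, 1, 0], [1, 1, 1]])  # right-padded batch

print(last_token_pool(hidden, mask_left))   # rows hidden[0, -1] and hidden[1, -1]
print(last_token_pool(hidden, mask_right))  # rows hidden[0, 1] and hidden[1, 2]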
Running main_embedding.py
uvicorn main_embedding:app --reload --host 0.0.0.0 --port 8080
INFO: Will watch for changes in these directories: ['/gemini/code']
INFO: Uvicorn running on http://0.0.0.0:8080 (Press CTRL+C to quit)
INFO: Started reloader process [1827] using StatReload
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:19<00:00, 4.96s/it]
INFO: Started server process [1829]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: 127.0.0.1:43076 - "POST /embedding/retrieve HTTP/1.1" 200 OK
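As written, AutoModel.from_pretrained typically loads the 8B model at full fp32 precision (roughly 32 GB of weights) unless a torch_dtype is specified; if memory is tight, passing torch_dtype=torch.float16 (or bfloat16) halves that, mirroring the optional fp16 line in the reranker script.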
Calling the API
curl -s -X POST "http://localhost:8080/embedding/retrieve" \
-H "Content-Type: application/json" \
-d '{
"task": "Given a web search query, retrieve relevant passages that answer the query",
"queries": [
"What is the capital of China?",
"Explain gravity"
],
"documents": [
"The capital of China is Beijing.",
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
]
}' | jq
Response
{
"scores": [
[
0.7493016719818115,
0.07506481558084488
],
[
0.0879596695303917,
0.6318400502204895
]
],
"query_count": 2,
"doc_count": 2,
"query_documents": {
"queries": [
"What is the capital of China?",
"Explain gravity"
],
"documents": [
"The capital of China is Beijing.",
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
]
}
}
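Because the embeddings are L2-normalized, each entry scores[i][j] is the cosine similarity between query i and document j; in the example above, each query scores highest against its matching document. A small post-processing sketch for ranking documents per query (resp stands in for the parsed JSON above, with scores rounded for brevity):

resp = {
    "scores": [[0.7493, 0.0751], [0.0880, 0.6318]],
    "query_documents": {
        "queries": ["What is the capital of China?", "Explain gravity"],
        "documents": ["The capital of China is Beijing.", "Gravity is a force that attracts two bodies towards each other."],
    },
}
docs = resp["query_documents"]["documents"]
for query, row in zip(resp["query_documents"]["queries"], resp["scores"]):
    ranked = sorted(zip(docs, row), key=lambda pair: pair[1], reverse=True)
    print(query, "->", ranked[0][0])  # best-matching document per query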