import numpy as np
from dataclasses import dataclass
import torch
from torch import nn
from typing import Optional, List, Dict, Tuple

from transformers.models.qwen3 import modeling_qwen3
from transformers.modeling_outputs import CausalLMOutputWithPast


@dataclass
class CausalLMOutputWithScores(CausalLMOutputWithPast):
    """Causal-LM output extended with reranking scores and pooled embeddings."""

    scores: Optional[torch.FloatTensor] = None
    query_embeds: Optional[torch.FloatTensor] = None
    doc_embeds: Optional[torch.FloatTensor] = None


def sanitize_input(text: str, special_tokens: Dict[str, str]) -> str:
    """Strip the model's special embedding tokens from user-supplied text."""
    for token in special_tokens.values():
        text = text.replace(token, "")
    return text


def format_docs_prompts_func(
    query: str,
    docs: List[str],
    instruction: Optional[str] = None,
    special_tokens: Optional[Dict[str, str]] = None,
    no_thinking: bool = True,
) -> str:
    """Build the chat-formatted listwise ranking prompt for one query and its docs."""
    special_tokens = special_tokens or {}
    query = sanitize_input(query, special_tokens)
    docs = [sanitize_input(doc, special_tokens) for doc in docs]
    prefix = (
        "<|im_start|>system\n"
        "You are a search relevance expert who can determine a ranking of the passages "
        "based on how relevant they are to the query. "
        "If the query is a question, how relevant a passage is depends on how well it answers the question. "
        "If not, try to analyze the intent of the query and assess how well each passage satisfies the intent. "
        "If an instruction is provided, you should follow the instruction when determining the ranking."
        "<|im_end|>\n<|im_start|>user\n"
    )
    suffix = "<|im_end|>\n<|im_start|>assistant\n"
    if no_thinking:
        # Append an empty think block (Qwen3 convention) so the model skips
        # chain-of-thought generation. The literal tags were stripped by HTML
        # processing in the source and are reconstructed here.
        suffix += "<think>\n\n</think>\n\n"
    doc_emb_token = special_tokens["doc_embed_token"]
    query_emb_token = special_tokens["query_embed_token"]
    prompt = (
        f"I will provide you with {len(docs)} passages, each indicated by a numerical identifier. "
        f"Rank the passages based on their relevance to query: {query}\n"
    )
    # NOTE: the XML-style wrapper tags below were also stripped in the source;
    # the tag names are reconstructed and may differ from the original template.
    if instruction:
        prompt += f"<instruction>\n{instruction}\n</instruction>\n"
    doc_prompts = [
        f'<passage id="{i + 1}">\n{doc}{doc_emb_token}\n</passage>'
        for i, doc in enumerate(docs)
    ]
    prompt += "\n".join(doc_prompts) + "\n"
    prompt += f"<query>\n{query}{query_emb_token}\n</query>"
    return prefix + prompt + suffix
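
# Illustrative sketch (not part of the original module): render the prompt for two
# toy passages to inspect the exact template fed to the model. The special-token
# strings mirror the ones registered on JinaForRanking below; the query and
# passage texts are assumed example data.
def _example_prompt() -> str:
    return format_docs_prompts_func(
        query="how does attention work in transformers?",
        docs=[
            "Self-attention lets each token weigh every other token.",
            "Convolutions aggregate local neighborhoods.",
        ],
        special_tokens={
            "query_embed_token": "<|rerank_token|>",
            "doc_embed_token": "<|embed_token|>",
        },
    )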

class JinaForRanking(modeling_qwen3.Qwen3ForCausalLM):
    """Qwen3-based listwise reranker that scores passages against a query."""

    def __init__(self, config):
        super().__init__(config)
        self.padding_side = "left"
        self.projector_dim = 512
        # The LM head is unused: scores come from projected hidden states.
        self.lm_head = nn.Identity()
        self.projector = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 2, bias=False),
            nn.ReLU(),
            nn.Linear(config.hidden_size // 2, self.projector_dim, bias=False),
        )
        self.post_init()
        self.special_tokens = {
            "query_embed_token": "<|rerank_token|>",
            "doc_embed_token": "<|embed_token|>",
        }
        self.doc_embed_token_id = 151670
        self.query_embed_token_id = 151671

    def forward(self, *args, **kwargs) -> CausalLMOutputWithScores:
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("use_cache", None)
        assert kwargs.pop("labels", None) is None, "labels should not be passed to forward()"
        input_ids = kwargs.pop("input_ids", None)
        outputs = super().forward(
            *args,
            input_ids=input_ids,
            use_cache=False,
            output_hidden_states=True,
            **kwargs,
        )
        hidden_states = outputs.hidden_states[-1]
        batch_size, _, dim = hidden_states.shape
        # Pool the hidden states at the positions of the special embedding tokens:
        # one query token and one token per document in each sequence.
        query_embed_token_indexes = torch.eq(input_ids, self.query_embed_token_id)
        doc_embed_token_indexes = torch.eq(input_ids, self.doc_embed_token_id)
        doc_embeds = hidden_states[doc_embed_token_indexes].view(batch_size, -1, dim)
        query_embeds = hidden_states[query_embed_token_indexes].unsqueeze(1)
        doc_embeds = self.projector(doc_embeds)
        query_embeds = self.projector(query_embeds)
        query_embeds_expanded = query_embeds.expand_as(doc_embeds)
        scores = torch.nn.functional.cosine_similarity(
            doc_embeds, query_embeds_expanded, dim=-1
        ).squeeze(-1)
        return CausalLMOutputWithScores(
            loss=None,
            logits=None,
            scores=scores,
            query_embeds=query_embeds,
            doc_embeds=doc_embeds,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _ensure_tokenizer(self):
        """Lazily load and configure the tokenizer on first use."""
        if not hasattr(self, "_tokenizer"):
            from transformers import AutoTokenizer

            self._tokenizer = AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)
            if self._tokenizer.pad_token is None:
                self._tokenizer.pad_token = self._tokenizer.unk_token
                self._tokenizer.pad_token_id = self._tokenizer.convert_tokens_to_ids(self._tokenizer.pad_token)
            self._tokenizer.padding_side = "left"

    def _truncate_texts(
        self,
        query: str,
        documents: List[str],
        max_query_length: int = 512,
        max_doc_length: int = 2048,
    ) -> Tuple[str, List[str], List[int], int]:
        """Truncate the query and documents to their token budgets; return token lengths."""
        self._ensure_tokenizer()
        docs = []
        doc_lengths = []
        for doc in documents:
            doc_tokens = self._tokenizer(doc, truncation=True, max_length=max_doc_length)
            if len(doc_tokens["input_ids"]) >= max_doc_length:
                # Re-decode so the prompt carries only the truncated text.
                doc = self._tokenizer.decode(doc_tokens["input_ids"])
            doc_lengths.append(len(doc_tokens["input_ids"]))
            docs.append(doc)
        query_tokens = self._tokenizer(query, truncation=True, max_length=max_query_length)
        if len(query_tokens["input_ids"]) >= max_query_length:
            query = self._tokenizer.decode(query_tokens["input_ids"])
        query_length = len(query_tokens["input_ids"])
        return query, docs, doc_lengths, query_length

    def _compute_single_batch(
        self,
        query: str,
        docs: List[str],
        instruction: Optional[str] = None,
    ) -> CausalLMOutputWithScores:
        """Run one forward pass over a single prompt containing all docs in the block."""
        self._ensure_tokenizer()
        device = next(self.parameters()).device
        prompt = format_docs_prompts_func(
            query,
            docs,
            instruction=instruction,
            special_tokens=self.special_tokens,
            no_thinking=True,
        )
        batch = self._tokenizer(
            text=[prompt],
            padding=True,
            padding_side="left",
            return_tensors="pt",
        ).to(device)
        return self.forward(**batch)

    def _calculate_cosine_scores(
        self,
        query_embeddings: np.ndarray,
        doc_embeddings: np.ndarray,
    ) -> np.ndarray:
        """Cosine similarity between a (1, dim) query matrix and a (num_docs, dim) doc matrix."""
        return np.dot(query_embeddings, doc_embeddings.T) / (
            np.linalg.norm(query_embeddings) * np.linalg.norm(doc_embeddings, axis=1)
        )
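
    # Block-packing summary (comment added for clarity; behavior unchanged):
    # rerank() packs documents into blocks until a block holds `block_size` docs
    # or the remaining token budget might not fit another max-length document.
    # Each block contributes one query embedding; the final query vector is the
    # average of the per-block query embeddings, weighted by each block's best
    # rescaled score (1 + s) / 2.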
    @torch.no_grad()
    def rerank(
        self,
        query: str,
        documents: List[str],
        top_n: Optional[int] = None,
        return_embeddings: bool = False,
        max_doc_length: int = 2048,
        max_query_length: int = 512,
    ) -> List[dict]:
        """
        Rerank documents by relevance to a query.

        Args:
            query: Search query string
            documents: List of document strings to rank
            top_n: Return only top N results (default: all)
            return_embeddings: Include embeddings in output (default: False)
            max_doc_length: Maximum tokens kept per document (default: 2048)
            max_query_length: Maximum tokens kept for the query (default: 512)

        Returns:
            List of dicts with keys:
            - document: Original document text
            - relevance_score: Similarity score (higher = more relevant)
            - index: Position in input documents list
            - embedding: Doc embedding if return_embeddings=True, else None
        """
        self._ensure_tokenizer()
        # Context budget reported by the tokenizer for this checkpoint.
        max_length = self._tokenizer.model_max_length
        # Hard cap on how many documents are packed into one forward pass.
        block_size = 125
        query, docs, doc_lengths, query_length = self._truncate_texts(
            query, documents, max_query_length, max_doc_length
        )
        # The query text appears twice in the prompt (task description and
        # trailing query block), hence the 2x reservation.
        length_capacity = max_length - 2 * query_length
        block_docs = []
        doc_embeddings = []
        query_embeddings = []
        block_weights = []

        def flush_block() -> None:
            outputs = self._compute_single_batch(query, block_docs, instruction=None)
            doc_embeddings.extend(outputs.doc_embeds[0].cpu().float().numpy())
            query_embeddings.append(outputs.query_embeds[0].cpu().float().numpy())
            block_scores = outputs.scores.view(-1).cpu().float().numpy()
            # Weight this block's query embedding by its best score, rescaled
            # from [-1, 1] to [0, 1].
            block_weights.append(((1.0 + block_scores) / 2.0).max())

        for length, doc in zip(doc_lengths, docs):
            block_docs.append(doc)
            length_capacity -= length
            # Flush when the block is full or another max-length doc may not fit.
            if len(block_docs) >= block_size or length_capacity <= max_doc_length:
                flush_block()
                block_docs = []
                length_capacity = max_length - 2 * query_length
        if len(block_docs) > 0:
            flush_block()

        query_embeddings = np.array(query_embeddings)
        doc_embeddings = np.array(doc_embeddings)
        # Weighted average of the per-block query embeddings; ensure a 2-D
        # (1, dim) matrix so the cosine-score matrix below is (1, num_docs).
        query_embeddings = np.average(
            query_embeddings, axis=0, weights=block_weights
        ).reshape(1, -1)
        scores = self._calculate_cosine_scores(query_embeddings, doc_embeddings)
        scores_argsort = np.argsort(scores[0])[::-1]
        # Default to returning every document.
        if top_n is None:
            top_n = len(documents)
        else:
            top_n = min(top_n, len(documents))
        return [
            {
                "document": documents[scores_argsort[i]],
                "relevance_score": float(scores[0][scores_argsort[i]]),
                "index": int(scores_argsort[i]),
                "embedding": doc_embeddings[scores_argsort[i]] if return_embeddings else None,
            }
            for i in range(top_n)
        ]
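
# Usage sketch (assumed, not from the original module): load a checkpoint that
# ships this architecture and rerank a few passages. The checkpoint path is a
# placeholder; substitute a real one.
if __name__ == "__main__":
    model = JinaForRanking.from_pretrained("path/to/reranker-checkpoint").eval()
    results = model.rerank(
        query="What is the capital of France?",
        documents=[
            "Paris is the capital and largest city of France.",
            "Berlin is the capital of Germany.",
            "The Eiffel Tower is in Paris.",
        ],
        top_n=2,
    )
    for r in results:
        print(f"{r['relevance_score']:.4f}  [{r['index']}]  {r['document']}")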