shubhamg2208 committed
Commit 123a7c2 · verified · 1 Parent(s): f8f0543

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
1
+ results_*
README.md ADDED
@@ -0,0 +1,240 @@
1
+ ---
2
+ license: apache-2.0
3
+ license_name: apache-2.0
4
+ license_link: https://www.apache.org/licenses/LICENSE-2.0
5
+ tags:
6
+ - text
7
+ - image
8
+ - video
9
+ - multimodal-embedding
10
+ - vidore
11
+ - colpali
12
+ - colqwen3
13
+ - multilingual-embedding
14
+ - quantized
15
+ - awq
16
+ - autoround
17
+ - w4a16
18
+ language:
19
+ - multilingual
20
+ library_name: transformers
21
+ pipeline_tag: visual-document-retrieval
22
+ base_model:
23
+ - TomoroAI/tomoro-colqwen3-embed-4b
24
+ ---
25
+
26
+ # TomoroAI/tomoro-ai-colqwen3-embed-4b-awq
27
+
28
+ ## Overview
29
+
30
+ This is a **W4A16 quantized** version of [TomoroAI/tomoro-colqwen3-embed-4b](https://huggingface.co/TomoroAI/tomoro-colqwen3-embed-4b), a state-of-the-art [ColPali](https://arxiv.org/abs/2407.01449)-style multimodal embedding model. The quantization was performed with [AutoRound](https://github.com/intel/auto-round) using the AutoAWQ backend.
31
+
32
+ The quantized model runs in **~3.5 GB of GPU memory** (vs. 8.4 GB for the original), enabling deployment on consumer GPUs while maintaining competitive retrieval performance.
33
+
34
+ ## Model Details
35
+
36
+ | Property | Value |
37
+ |----------|-------|
38
+ | **Original Model** | [TomoroAI/tomoro-colqwen3-embed-4b](https://huggingface.co/TomoroAI/tomoro-colqwen3-embed-4b) |
39
+ | **Parameters** | 4.0B |
40
+ | **Quantization** | W4A16 (4-bit weights, 16-bit activations) |
41
+ | **Quantization Method** | AutoRound with AutoAWQ backend |
42
+ | **Calibration Sequence Length** | 1024 |
43
+ | **Memory Usage (Quantized)** | ~3.5 GB |
44
+ | **Memory Usage (Original)** | 8.4 GB |
45
+ | **Embedding Dimension** | 320 |
46
+ | **Max Visual Tokens** | 1280 |
47
+
48
+ ## Quantization Configuration
49
+
50
+ | Parameter | Value |
51
+ |-----------|-------|
52
+ | **Bits** | 4 |
53
+ | **Group Size** | 128 |
54
+ | **Symmetric** | True |
55
+ | **Calibration Dataset** | NeelNanda/pile-10k (AutoRound default) |
56
+ | **Calibration Sequence Length** | 1024 |
57
+ | **Iterations** | 1000 |
58
+ | **Number of Samples** | 560 |
59
+ | **Batch Size** | 80 |
60
+ | **Quantized Layers** | 252 |
61
+ | **FP16 Layers (Vision)** | 105 |
62
+
63
+ > **Note:** Only the text tower (language model) is quantized. The vision encoder remains in FP16/BF16 to preserve visual feature quality.
64
+
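+ For reference, the recipe above roughly corresponds to the AutoRound invocation sketched below. This is a minimal, illustrative reconstruction based on the `quantization_config` shipped in `config.json` (bits, group size, iterations, samples, calibration seqlen, batch size); the exact AutoRound API and export format can differ between versions, so treat it as a sketch rather than the exact command used.
+
+ ```python
+ from transformers import AutoModel, AutoTokenizer
+ from auto_round import AutoRound  # pip install auto-round==0.9.2
+
+ BASE_ID = "TomoroAI/tomoro-colqwen3-embed-4b"
+ model = AutoModel.from_pretrained(BASE_ID, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(BASE_ID, trust_remote_code=True)
+
+ # Settings mirror this repo's quantization_config: 4-bit symmetric weights,
+ # group size 128, 1000 tuning iterations, 560 calibration samples of length
+ # 1024 (NeelNanda/pile-10k by default), batch size 80.
+ autoround = AutoRound(
+     model,
+     tokenizer,
+     bits=4,
+     group_size=128,
+     sym=True,
+     iters=1000,
+     nsamples=560,
+     seqlen=1024,
+     batch_size=80,
+ )
+ autoround.quantize()
+ # Only `vlm.model.language_model.layers` were quantized (see
+ # `quantization_config.block_name_to_quantize`); the vision tower stays in BF16.
+ autoround.save_quantized("./colqwen3-embed-4b-w4a16", format="auto_awq")
+ ```
+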
65
+ ## Performance
66
+
67
+ ### NDCG@5 on ViDoRe Benchmark (All Languages)
68
+
69
+ | Model | Average NDCG@5 | Change |
70
+ |-------|----------------|--------|
71
+ | Original (FP16) | 0.70023 | - |
72
+ | **This Model (W4A16, seqlen=1024)** | **0.69768** | **-0.36%** |
73
+
74
+ ### NDCG@5 on ViDoRe Benchmark (English Only)
75
+
76
+ | Model | Average NDCG@5 | Change |
77
+ |-------|----------------|--------|
78
+ | Original (FP16) | 0.74743 | - |
79
+ | **This Model (W4A16, seqlen=1024)** | **0.74582** | **-0.21%** |
80
+
81
+ ### Performance Summary
82
+
83
+ - **Benchmarks Improved:** 17
84
+ - **Benchmarks Degraded:** 23
85
+ - **Overall Quality Retention:** ~99.6%
86
+
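+ > Quality retention here is presumably the ratio of the average all-languages scores: 0.69768 / 0.70023 ≈ 0.9964, i.e. roughly 99.6% of the original model's NDCG@5.
+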
87
+ ### Benchmark Comparison Charts
88
+
89
+ > **Note:** Here, "seqlen" refers to the **calibration dataset sequence length used during quantization**, not the maximum sequence length supported by the original model. The model retains the full sequence length of the original, but quantization statistics are collected with the calibration seqlen shown.
90
+
91
+
92
+ #### Performance Comparison (All Languages)
93
+
94
+ ![Performance Comparison - All Languages](https://raw.githubusercontent.com/goodhamgupta/evaluation/main/performance_comparison_4B_all_languages.png)
95
+
96
+ #### Performance Difference vs Original (All Languages)
97
+
98
+ ![Performance Difference - All Languages](https://raw.githubusercontent.com/goodhamgupta/evaluation/main/performance_diff_4B_all_languages.png)
99
+
100
+ #### Performance Comparison (English Only)
101
+
102
+ ![Performance Comparison - English](https://raw.githubusercontent.com/goodhamgupta/evaluation/main/performance_comparison_4B_english.png)
103
+
104
+ #### Performance Difference vs Original (English Only)
105
+
106
+ ![Performance Difference - English](https://raw.githubusercontent.com/goodhamgupta/evaluation/main/performance_diff_4B_english.png)
107
+
108
+ ## Memory Efficiency
109
+
110
+ The quantized model enables deployment on GPUs with limited memory:
111
+
112
+ | GPU Memory | Original Model | Quantized Model |
113
+ |------------|----------------|-----------------|
114
+ | 8 GB | Marginal | Fits with batch size ~64 |
115
+ | 12 GB | Fits comfortably | Fits with batch size ~256 |
116
+ | 16 GB | Fits comfortably | High batch sizes possible |
117
+ | 24 GB | Fits comfortably | High batch sizes possible |
118
+
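+ If you want to verify the footprint on your own hardware, a quick check of peak GPU memory after encoding one batch looks like the sketch below (it assumes the `encode_docs` helper and `doc_urls` list from the Usage section that follows; actual numbers vary with batch size and image resolution):
+
+ ```python
+ import torch
+
+ # Reset the peak-memory counter, run one representative batch, read the peak back.
+ torch.cuda.reset_peak_memory_stats()
+ _ = encode_docs(doc_urls)
+ peak_gb = torch.cuda.max_memory_allocated() / 1024**3
+ print(f"Peak GPU memory: {peak_gb:.2f} GB")
+ ```
+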
119
+ ## Usage
120
+
121
+ ### Prerequisites
122
+
123
+ ```bash
124
+ pip install torch==2.8.0 torchvision==0.23.0 --index-url https://download.pytorch.org/whl/cu128
125
+ pip install auto-round==0.9.2
126
+ pip install autoawq==0.2.9
127
+ pip install transformers pillow requests
128
+ pip install flash-attn --no-build-isolation # Optional but recommended
129
+ ```
130
+
131
+ ### Inference Code
132
+
133
+ ```python
134
+ import torch
135
+ from transformers import AutoModel, AutoProcessor
136
+ from PIL import Image
137
+ import requests
138
+ from io import BytesIO
139
+
140
+ # Configuration
141
+ MODEL_ID = "shubhamg2208/tomoro-ai-colqwen3-embed-4b-w4a16-autoawq-seqlen-1024"
142
+ DTYPE = torch.bfloat16
143
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
144
+
145
+ # Load Model & Processor
146
+ processor = AutoProcessor.from_pretrained(
147
+ MODEL_ID,
148
+ trust_remote_code=True,
149
+ max_num_visual_tokens=1280,
150
+ )
151
+ model = AutoModel.from_pretrained(
152
+ MODEL_ID,
153
+ dtype=DTYPE,
154
+ attn_implementation="sdpa", # Use "flash_attention_2" if available
155
+ trust_remote_code=True,
156
+ device_map=DEVICE,
157
+ ).eval()
158
+
159
+ # Sample queries and documents
160
+ queries = [
161
+ "Retrieve the city of Singapore",
162
+ "Retrieve the city of Beijing",
163
+ ]
164
+ doc_urls = [
165
+ "https://upload.wikimedia.org/wikipedia/commons/2/27/Singapore_skyline_2022.jpg",
166
+ "https://upload.wikimedia.org/wikipedia/commons/6/61/Beijing_skyline_at_night.JPG",
167
+ ]
168
+
169
+ def load_image(url: str) -> Image.Image:
170
+ headers = {"User-Agent": "Mozilla/5.0"}
171
+ resp = requests.get(url, headers=headers, timeout=10)
172
+ resp.raise_for_status()
173
+ return Image.open(BytesIO(resp.content)).convert("RGB")
174
+
175
+ def encode_queries(texts):
176
+ batch = processor.process_texts(texts=texts)
177
+ batch = {k: v.to(DEVICE) for k, v in batch.items()}
178
+ with torch.inference_mode():
179
+ out = model(**batch)
180
+ return out.embeddings.to(torch.bfloat16).cpu()
181
+
182
+ def encode_docs(urls):
183
+ images = [load_image(url) for url in urls]
184
+ features = processor.process_images(images=images)
185
+ features = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in features.items()}
186
+ with torch.inference_mode():
187
+ out = model(**features)
188
+ return out.embeddings.to(torch.bfloat16).cpu()
189
+
190
+ # Encode and score
191
+ query_embeddings = encode_queries(queries)
192
+ doc_embeddings = encode_docs(doc_urls)
193
+ scores = processor.score_multi_vector(query_embeddings, doc_embeddings)
194
+ print(scores)
195
+ ```
196
+
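+ `score_multi_vector` computes the ColBERT-style late-interaction (MaxSim) score: for each query token, take the maximum similarity over all document tokens, then sum over query tokens. For a single query/document pair it is conceptually equivalent to the sketch below (illustrative only; the bundled processor method additionally handles batching and padding):
+
+ ```python
+ import torch
+
+ def maxsim(query_emb: torch.Tensor, doc_emb: torch.Tensor) -> torch.Tensor:
+     """Late-interaction score for one query (n_q, d) vs. one document (n_d, d)."""
+     sim = query_emb @ doc_emb.T          # (n_q, n_d) token-token similarities
+     return sim.max(dim=1).values.sum()   # max over doc tokens, sum over query tokens
+
+ # Matches processor.score_multi_vector([q], [d]) for a single pair.
+ score = maxsim(query_embeddings[0].float(), doc_embeddings[0].float())
+ ```
+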
197
+ ## Comparison with Other Calibration Lengths
198
+
199
+ | Calibration Length | Avg NDCG@5 | Delta | Best For |
200
+ |--------------------|------------|-------|----------|
201
+ | seqlen=256 | 0.69611 | -0.59% | Short document retrieval |
202
+ | seqlen=512 | 0.69696 | -0.47% | Balanced use cases |
203
+ | seqlen=1024 | 0.69768 | -0.36% | Long document retrieval |
204
+
205
+ ## Limitations
206
+
207
+ - **Reduced Precision:** 4-bit quantization introduces some accuracy loss compared to the original FP16 model.
208
+ - **Vision Encoder:** The vision encoder is not quantized to preserve visual feature quality.
209
+ - **Inference Backend:** Performance depends on the inference backend (AutoAWQ, vLLM, etc.).
210
+
211
+ ## License
212
+
213
+ This model is released under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0), consistent with the original model.
214
+
215
+ ## Acknowledgements
216
+
217
+ - **Original Model:** [TomoroAI/tomoro-colqwen3-embed-4b](https://huggingface.co/TomoroAI/tomoro-colqwen3-embed-4b) by [Tomoro AI](https://tomoro.ai/)
218
+ - **Quantization Tool:** [AutoRound](https://github.com/intel/auto-round) by Intel
219
+ - **Base Architecture:** [Qwen3-VL](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct) by Alibaba
220
+
221
+ ## Citation
222
+
223
+ If you use this model, please cite both the original model and this quantized version:
224
+
225
+ ```bibtex
226
+ @misc{huang2025beyond,
227
+ author = {Huang, Xin and Tan, Kye Min},
228
+ title = {Beyond Text: Unlocking True Multimodal, End-to-end RAG with Tomoro ColQwen3},
229
+ year = {2025},
230
+ url = {https://tomoro.ai/insights/beyond-text-unlocking-true-multimodal-end-to-end-rag-with-tomoro-colqwen3},
231
+ publisher = {Tomoro.ai}
232
+ }
233
+
234
+ @misc{autoround,
235
+ author = {Intel Corporation},
236
+ title = {AutoRound: Advanced Weight-Only Quantization Algorithm},
237
+ year = {2024},
238
+ url = {https://github.com/intel/auto-round}
239
+ }
240
+ ```
added_tokens.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|box_end|>": 151649,
9
+ "<|box_start|>": 151648,
10
+ "<|endoftext|>": 151643,
11
+ "<|file_sep|>": 151664,
12
+ "<|fim_middle|>": 151660,
13
+ "<|fim_pad|>": 151662,
14
+ "<|fim_prefix|>": 151659,
15
+ "<|fim_suffix|>": 151661,
16
+ "<|im_end|>": 151645,
17
+ "<|im_start|>": 151644,
18
+ "<|image_pad|>": 151655,
19
+ "<|object_ref_end|>": 151647,
20
+ "<|object_ref_start|>": 151646,
21
+ "<|quad_end|>": 151651,
22
+ "<|quad_start|>": 151650,
23
+ "<|repo_name|>": 151663,
24
+ "<|video_pad|>": 151656,
25
+ "<|vision_end|>": 151653,
26
+ "<|vision_pad|>": 151654,
27
+ "<|vision_start|>": 151652
28
+ }
chat_template.jinja ADDED
@@ -0,0 +1,120 @@
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {%- if messages[0].content is string %}
5
+ {{- messages[0].content }}
6
+ {%- else %}
7
+ {%- for content in messages[0].content %}
8
+ {%- if 'text' in content %}
9
+ {{- content.text }}
10
+ {%- endif %}
11
+ {%- endfor %}
12
+ {%- endif %}
13
+ {{- '\n\n' }}
14
+ {%- endif %}
15
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
16
+ {%- for tool in tools %}
17
+ {{- "\n" }}
18
+ {{- tool | tojson }}
19
+ {%- endfor %}
20
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
21
+ {%- else %}
22
+ {%- if messages[0].role == 'system' %}
23
+ {{- '<|im_start|>system\n' }}
24
+ {%- if messages[0].content is string %}
25
+ {{- messages[0].content }}
26
+ {%- else %}
27
+ {%- for content in messages[0].content %}
28
+ {%- if 'text' in content %}
29
+ {{- content.text }}
30
+ {%- endif %}
31
+ {%- endfor %}
32
+ {%- endif %}
33
+ {{- '<|im_end|>\n' }}
34
+ {%- endif %}
35
+ {%- endif %}
36
+ {%- set image_count = namespace(value=0) %}
37
+ {%- set video_count = namespace(value=0) %}
38
+ {%- for message in messages %}
39
+ {%- if message.role == "user" %}
40
+ {{- '<|im_start|>' + message.role + '\n' }}
41
+ {%- if message.content is string %}
42
+ {{- message.content }}
43
+ {%- else %}
44
+ {%- for content in message.content %}
45
+ {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}
46
+ {%- set image_count.value = image_count.value + 1 %}
47
+ {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
48
+ <|vision_start|><|image_pad|><|vision_end|>
49
+ {%- elif content.type == 'video' or 'video' in content %}
50
+ {%- set video_count.value = video_count.value + 1 %}
51
+ {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
52
+ <|vision_start|><|video_pad|><|vision_end|>
53
+ {%- elif 'text' in content %}
54
+ {{- content.text }}
55
+ {%- endif %}
56
+ {%- endfor %}
57
+ {%- endif %}
58
+ {{- '<|im_end|>\n' }}
59
+ {%- elif message.role == "assistant" %}
60
+ {{- '<|im_start|>' + message.role + '\n' }}
61
+ {%- if message.content is string %}
62
+ {{- message.content }}
63
+ {%- else %}
64
+ {%- for content_item in message.content %}
65
+ {%- if 'text' in content_item %}
66
+ {{- content_item.text }}
67
+ {%- endif %}
68
+ {%- endfor %}
69
+ {%- endif %}
70
+ {%- if message.tool_calls %}
71
+ {%- for tool_call in message.tool_calls %}
72
+ {%- if (loop.first and message.content) or (not loop.first) %}
73
+ {{- '\n' }}
74
+ {%- endif %}
75
+ {%- if tool_call.function %}
76
+ {%- set tool_call = tool_call.function %}
77
+ {%- endif %}
78
+ {{- '<tool_call>\n{"name": "' }}
79
+ {{- tool_call.name }}
80
+ {{- '", "arguments": ' }}
81
+ {%- if tool_call.arguments is string %}
82
+ {{- tool_call.arguments }}
83
+ {%- else %}
84
+ {{- tool_call.arguments | tojson }}
85
+ {%- endif %}
86
+ {{- '}\n</tool_call>' }}
87
+ {%- endfor %}
88
+ {%- endif %}
89
+ {{- '<|im_end|>\n' }}
90
+ {%- elif message.role == "tool" %}
91
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
92
+ {{- '<|im_start|>user' }}
93
+ {%- endif %}
94
+ {{- '\n<tool_response>\n' }}
95
+ {%- if message.content is string %}
96
+ {{- message.content }}
97
+ {%- else %}
98
+ {%- for content in message.content %}
99
+ {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}
100
+ {%- set image_count.value = image_count.value + 1 %}
101
+ {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
102
+ <|vision_start|><|image_pad|><|vision_end|>
103
+ {%- elif content.type == 'video' or 'video' in content %}
104
+ {%- set video_count.value = video_count.value + 1 %}
105
+ {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
106
+ <|vision_start|><|video_pad|><|vision_end|>
107
+ {%- elif 'text' in content %}
108
+ {{- content.text }}
109
+ {%- endif %}
110
+ {%- endfor %}
111
+ {%- endif %}
112
+ {{- '\n</tool_response>' }}
113
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
114
+ {{- '<|im_end|>\n' }}
115
+ {%- endif %}
116
+ {%- endif %}
117
+ {%- endfor %}
118
+ {%- if add_generation_prompt %}
119
+ {{- '<|im_start|>assistant\n' }}
120
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,86 @@
1
+ {
2
+ "architectures": [
3
+ "ColQwen3"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_colqwen3.ColQwen3Config",
7
+ "AutoModel": "modeling_colqwen3.ColQwen3"
8
+ },
9
+ "dtype": "bfloat16",
10
+ "embed_dim": 320,
11
+ "image_token_id": 151655,
12
+ "initializer_range": 0.02,
13
+ "max_num_visual_tokens": 1280,
14
+ "model_type": "colqwen3",
15
+ "padding_side": "left",
16
+ "quantization_config": {
17
+ "autoround_version": "0.9.2",
18
+ "batch_size": 80,
19
+ "bits": 4,
20
+ "block_name_to_quantize": "vlm.model.language_model.layers",
21
+ "data_type": "int",
22
+ "group_size": 128,
23
+ "iters": 1000,
24
+ "nsamples": 560,
25
+ "packing_format": "auto_round:auto_gptq",
26
+ "quant_method": "auto-round",
27
+ "seqlen": 1024,
28
+ "sym": true
29
+ },
30
+ "text_config": {
31
+ "attention_bias": false,
32
+ "attention_dropout": 0.0,
33
+ "bos_token_id": 151643,
34
+ "dtype": "bfloat16",
35
+ "eos_token_id": 151645,
36
+ "head_dim": 128,
37
+ "hidden_act": "silu",
38
+ "hidden_size": 2560,
39
+ "initializer_range": 0.02,
40
+ "intermediate_size": 9728,
41
+ "max_position_embeddings": 262144,
42
+ "model_type": "qwen3_vl_text",
43
+ "num_attention_heads": 32,
44
+ "num_hidden_layers": 36,
45
+ "num_key_value_heads": 8,
46
+ "rms_norm_eps": 1e-06,
47
+ "rope_scaling": {
48
+ "mrope_interleaved": true,
49
+ "mrope_section": [
50
+ 24,
51
+ 20,
52
+ 20
53
+ ],
54
+ "rope_type": "default"
55
+ },
56
+ "rope_theta": 5000000,
57
+ "tie_word_embeddings": true,
58
+ "use_cache": true,
59
+ "vocab_size": 151936
60
+ },
61
+ "transformers_version": "4.57.3",
62
+ "video_token_id": 151656,
63
+ "vision_config": {
64
+ "deepstack_visual_indexes": [
65
+ 5,
66
+ 11,
67
+ 17
68
+ ],
69
+ "depth": 24,
70
+ "dtype": "bfloat16",
71
+ "hidden_act": "gelu_pytorch_tanh",
72
+ "hidden_size": 1024,
73
+ "in_channels": 3,
74
+ "initializer_range": 0.02,
75
+ "intermediate_size": 4096,
76
+ "model_type": "qwen3_vl",
77
+ "num_heads": 16,
78
+ "num_position_embeddings": 2304,
79
+ "out_hidden_size": 2560,
80
+ "patch_size": 16,
81
+ "spatial_merge_size": 2,
82
+ "temporal_patch_size": 2
83
+ },
84
+ "vision_end_token_id": 151653,
85
+ "vision_start_token_id": 151652
86
+ }
configuration_colqwen3.py ADDED
@@ -0,0 +1,112 @@
1
+ """
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ Configuration for ColQwen3, adapted to mirror the ColQwen2 structure.
17
+ """
18
+
19
+ from copy import deepcopy
20
+ from typing import Any
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.models.auto import CONFIG_MAPPING
24
+ from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLTextConfig, Qwen3VLVisionConfig
25
+ from transformers.utils import logging
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class ColQwen3Config(PretrainedConfig):
32
+ """Configuration for ColQwen3 retrieval model."""
33
+
34
+ model_type = "colqwen3"
35
+ sub_configs: dict[str, Any] = {"vision_config": Qwen3VLVisionConfig, "text_config": Qwen3VLTextConfig}
36
+
37
+ def __init__(
38
+ self,
39
+ vision_config: Any = None,
40
+ text_config: Any = None,
41
+ embed_dim: int = 320,
42
+ padding_side: str = "left",
43
+ initializer_range: float = 0.02,
44
+ dtype: str | None = None,
45
+ **kwargs,
46
+ ):
47
+ if vision_config is None or text_config is None:
48
+ base_vlm_config = CONFIG_MAPPING["qwen3_vl"]()
49
+ if vision_config is None:
50
+ vision_config = deepcopy(base_vlm_config.vision_config)
51
+ logger.info("`vision_config` is `None`. Initializing with the default `Qwen3VLVisionConfig`.")
52
+ if text_config is None:
53
+ text_config = deepcopy(base_vlm_config.text_config)
54
+ logger.info("`text_config` is `None`. Initializing with the default `Qwen3VLTextConfig`.")
55
+
56
+ if isinstance(vision_config, dict):
57
+ vision_config = Qwen3VLVisionConfig(**deepcopy(vision_config))
58
+ elif not isinstance(vision_config, PretrainedConfig):
59
+ raise TypeError(
60
+ f"Invalid type for `vision_config`. Expected `PretrainedConfig`, `dict`, or `None`, got {type(vision_config)}."
61
+ )
62
+
63
+ if isinstance(text_config, dict):
64
+ text_config = Qwen3VLTextConfig(**deepcopy(text_config))
65
+ elif not isinstance(text_config, PretrainedConfig):
66
+ raise TypeError(
67
+ f"Invalid type for `text_config`. Expected `PretrainedConfig`, `dict`, or `None`, got {type(text_config)}."
68
+ )
69
+
70
+ if embed_dim <= 0:
71
+ raise ValueError(f"`embed_dim` must be positive, got {embed_dim}.")
72
+
73
+ super().__init__(**kwargs)
74
+ self.vision_config = vision_config
75
+ self.text_config = text_config
76
+ self.embed_dim = embed_dim
77
+ self.padding_side = padding_side
78
+ self.initializer_range = initializer_range
79
+ # Preserve incoming dtype so downstream models avoid attribute errors
80
+ self.dtype = dtype or getattr(self, "dtype", None)
81
+
82
+ @classmethod
83
+ def from_base_config(cls, base_config: PretrainedConfig) -> "ColQwen3Config":
84
+ """Upgrade a base Qwen3VLConfig-like config into ColQwen3Config."""
85
+ if isinstance(base_config, dict):
86
+ data = dict(base_config)
87
+ else:
88
+ data = base_config.to_dict()
89
+
90
+ vision_cfg = data.get("vision_config")
91
+ if isinstance(vision_cfg, dict):
92
+ data["vision_config"] = Qwen3VLVisionConfig.from_dict(vision_cfg)
93
+
94
+ text_cfg = data.get("text_config")
95
+ if isinstance(text_cfg, dict):
96
+ data["text_config"] = Qwen3VLTextConfig.from_dict(text_cfg)
97
+
98
+ data.setdefault("model_type", cls.model_type)
99
+ if hasattr(base_config, "dtype"):
100
+ data.setdefault("dtype", getattr(base_config, "dtype"))
101
+ elif hasattr(base_config, "torch_dtype") and base_config.torch_dtype is not None:
102
+ data.setdefault("dtype", str(base_config.torch_dtype))
103
+
104
+ return cls.from_dict(data)
105
+
106
+ def get_text_config(self, *args, **kwargs) -> PretrainedConfig:
107
+ return self.text_config
108
+
109
+
110
+ DEFAULT_CONFIG = ColQwen3Config()
111
+
112
+ __all__ = ["ColQwen3Config", "DEFAULT_CONFIG"]
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2ba6432335afb9af877c74aa293d536bd2f4178d4f75b77c52732ceaeb331d4
3
+ size 3498416592
modeling_colqwen3.py ADDED
@@ -0,0 +1,260 @@
1
+ """
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ Modeling for ColQwen3 retrieval, aligned with the ColQwen2 reference implementation.
17
+ """
18
+
19
+ from dataclasses import dataclass
20
+ from typing import Optional
21
+
22
+ from torch import nn
23
+ from transformers import AutoModelForImageTextToText
24
+ from transformers.configuration_utils import PretrainedConfig
25
+ from transformers.cache_utils import Cache
26
+ from transformers.modeling_utils import PreTrainedModel
27
+ from transformers.utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available, logging
28
+ from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig
29
+
30
+ from .configuration_colqwen3 import ColQwen3Config
31
+
32
+
33
+ if is_torch_available():
34
+ import torch
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+
39
+ @auto_docstring
40
+ class ColQwen3PreTrainedModel(PreTrainedModel):
41
+ config_class = ColQwen3Config
42
+ base_model_prefix = "model"
43
+ _no_split_modules = []
44
+ _supports_sdpa = True
45
+ _supports_flash_attn = True
46
+ _supports_flex_attn = True
47
+
48
+ def _init_weights(self, module):
49
+ std = (
50
+ self.config.initializer_range
51
+ if hasattr(self.config, "initializer_range")
52
+ else getattr(self.config.text_config, "initializer_range", 0.02)
53
+ )
54
+
55
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
56
+ module.weight.data.normal_(mean=0.0, std=std)
57
+ if module.bias is not None:
58
+ module.bias.data.zero_()
59
+ elif isinstance(module, nn.Embedding):
60
+ module.weight.data.normal_(mean=0.0, std=std)
61
+ if module.padding_idx is not None:
62
+ module.weight.data[module.padding_idx].zero_()
63
+
64
+
65
+ @dataclass
66
+ @auto_docstring(
67
+ custom_intro="""
68
+ Base class for ColQwen3 embeddings output.
69
+ """
70
+ )
71
+ class ColQwen3ForRetrievalOutput(ModelOutput):
72
+ r"""
73
+     embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, embed_dim)`):
74
+ The embeddings of the model.
75
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
76
+ It is a [`~cache_utils.Cache`] instance.
77
+ """
78
+
79
+ loss: Optional[torch.FloatTensor] = None
80
+ embeddings: Optional[torch.Tensor] = None
81
+ past_key_values: Optional[Cache] = None
82
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
83
+ attentions: Optional[tuple[torch.FloatTensor]] = None
84
+
85
+
86
+ @auto_docstring(
87
+ custom_intro="""
88
+ ColQwen3 retrieval model that mirrors the ColQwen2 late-interaction pipeline while using a Qwen3-VL backbone.
89
+ """
90
+ )
91
+ class ColQwen3(ColQwen3PreTrainedModel):
92
+ _checkpoint_conversion_mapping = {
93
+ # Legacy checkpoints saved from a bare Qwen3VLModel (no `vlm.` nesting).
94
+ r"^model\.visual": "vlm.model.visual",
95
+ r"^model\.language_model": "vlm.model.language_model",
96
+ r"^model\.": "vlm.model.",
97
+ r"^visual": "vlm.model.visual",
98
+ r"^language_model": "vlm.model.language_model",
99
+ r"^custom_text_proj": "embedding_proj_layer",
100
+ }
101
+ config_class = ColQwen3Config
102
+ model_type = ColQwen3Config.model_type
103
+
104
+ def __init__(
105
+ self,
106
+ config: ColQwen3Config,
107
+ attn_impl: Optional[str] = None,
108
+ mask_non_image_embeddings: bool = False,
109
+ ):
110
+ """
111
+ Args:
112
+ config (ColQwen3Config): Configuration carrying nested vision/text configs for the retrieval model.
113
+ attn_impl (Optional[str], optional): Attention implementation forwarded to the VLM (e.g., "flash_attention_2"). Defaults to None.
114
+ mask_non_image_embeddings (bool, optional): If True, zero out non-image embeddings after projection. Defaults to False.
115
+ """
116
+ super().__init__(config)
117
+ self.config = config
118
+
119
+ vision_cfg = (
120
+ config.vision_config.to_dict() if isinstance(config.vision_config, PretrainedConfig) else config.vision_config
121
+ )
122
+ text_cfg = config.text_config.to_dict() if isinstance(config.text_config, PretrainedConfig) else config.text_config
123
+
124
+ vlm_config = Qwen3VLConfig(
125
+ text_config=text_cfg,
126
+ vision_config=vision_cfg,
127
+ image_token_id=getattr(config, "image_token_id", 151655),
128
+ video_token_id=getattr(config, "video_token_id", 151656),
129
+ vision_start_token_id=getattr(config, "vision_start_token_id", 151652),
130
+ vision_end_token_id=getattr(config, "vision_end_token_id", 151653),
131
+ tie_word_embeddings=getattr(config.text_config, "tie_word_embeddings", False),
132
+ )
133
+ self.vlm = AutoModelForImageTextToText.from_config(vlm_config)
134
+
135
+ self.embedding_dim = self.config.embed_dim
136
+ self.embedding_proj_layer = nn.Linear(
137
+ self.vlm.config.text_config.hidden_size,
138
+ self.embedding_dim,
139
+ )
140
+ self.padding_side = getattr(config, "padding_side", "left")
141
+ self.mask_non_image_embeddings = mask_non_image_embeddings
142
+ self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])]
143
+
144
+ self.post_init()
145
+
146
+ if attn_impl is not None and hasattr(self.vlm, "set_attn_implementation"):
147
+ self.vlm.set_attn_implementation(attn_impl)
148
+
149
+ @classmethod
150
+ def from_pretrained(cls, *args, config: Optional[ColQwen3Config] = None, **kwargs):
151
+ key_mapping = kwargs.pop("key_mapping", None)
152
+ if key_mapping is None:
153
+ key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
154
+
155
+ return super().from_pretrained(*args, config=config, **kwargs, key_mapping=key_mapping)
156
+
157
+ @can_return_tuple
158
+ @auto_docstring
159
+ def forward(
160
+ self,
161
+ input_ids: Optional[torch.LongTensor] = None,
162
+ attention_mask: Optional[torch.Tensor] = None,
163
+ position_ids: Optional[torch.LongTensor] = None,
164
+ past_key_values: Optional[Cache] = None,
165
+ inputs_embeds: Optional[torch.FloatTensor] = None,
166
+ labels: Optional[torch.LongTensor] = None,
167
+ use_cache: Optional[bool] = None,
168
+ output_attentions: Optional[bool] = None,
169
+ output_hidden_states: Optional[bool] = None,
170
+ return_dict: Optional[bool] = None,
171
+ pixel_values: Optional[torch.Tensor] = None,
172
+ image_grid_thw: Optional[torch.LongTensor] = None,
173
+ cache_position: Optional[torch.LongTensor] = None,
174
+ pixel_values_videos: Optional[torch.Tensor] = None,
175
+ video_grid_thw: Optional[torch.LongTensor] = None,
176
+ ) -> ColQwen3ForRetrievalOutput:
177
+ r"""
178
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
179
+ The temporal, height and width of feature shape of each image in LLM.
180
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
181
+ The temporal, height and width of feature shape of each video in LLM.
182
+ """
183
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
184
+ output_hidden_states = (
185
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
186
+ )
187
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
188
+
189
+ vlm_output = self.vlm.model(
190
+ input_ids=input_ids,
191
+ position_ids=position_ids,
192
+ attention_mask=attention_mask,
193
+ past_key_values=past_key_values,
194
+ inputs_embeds=inputs_embeds,
195
+ pixel_values_videos=pixel_values_videos,
196
+ use_cache=use_cache,
197
+ output_attentions=output_attentions,
198
+ output_hidden_states=output_hidden_states,
199
+ return_dict=return_dict,
200
+ pixel_values=pixel_values,
201
+ image_grid_thw=image_grid_thw,
202
+ video_grid_thw=video_grid_thw,
203
+ cache_position=cache_position,
204
+ )
205
+
206
+ vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
207
+
208
+ last_hidden_states = vlm_output[0]
209
+ proj_dtype = self.embedding_proj_layer.weight.dtype
210
+ embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype))
211
+
212
+ denom = embeddings.norm(dim=-1, keepdim=True).clamp_min(torch.finfo(embeddings.dtype).eps)
213
+ embeddings = embeddings / denom
214
+ if attention_mask is not None:
215
+ embeddings = embeddings * attention_mask.unsqueeze(-1)
216
+
217
+ if pixel_values is not None and self.mask_non_image_embeddings:
218
+ image_mask = (input_ids == self.vlm.config.image_token_id).unsqueeze(-1)
219
+ embeddings = embeddings * image_mask
220
+
221
+ return ColQwen3ForRetrievalOutput(
222
+ embeddings=embeddings,
223
+ past_key_values=vlm_output.past_key_values,
224
+ hidden_states=vlm_hidden_states,
225
+ attentions=vlm_output.attentions,
226
+ )
227
+
228
+ def get_input_embeddings(self):
229
+ return self.vlm.get_input_embeddings()
230
+
231
+ def set_input_embeddings(self, value):
232
+ self.vlm.set_input_embeddings(value)
233
+
234
+ def get_output_embeddings(self):
235
+ return self.vlm.get_output_embeddings()
236
+
237
+ def set_output_embeddings(self, new_embeddings):
238
+ self.vlm.set_output_embeddings(new_embeddings)
239
+
240
+ def tie_weights(self):
241
+ return self.vlm.tie_weights()
242
+
243
+ def resize_token_embeddings(
244
+ self,
245
+ new_num_tokens: Optional[int] = None,
246
+ pad_to_multiple_of: Optional[int] = None,
247
+ mean_resizing: bool = True,
248
+ ) -> nn.Embedding:
249
+ model_embeds = self.vlm.resize_token_embeddings(
250
+ new_num_tokens=new_num_tokens,
251
+ pad_to_multiple_of=pad_to_multiple_of,
252
+ mean_resizing=mean_resizing,
253
+ )
254
+
255
+ self.vlm.config.text_config.vocab_size = model_embeds.num_embeddings
256
+ self.vlm.config.vocab_size = model_embeds.num_embeddings
257
+ return model_embeds
258
+
259
+
260
+ __all__ = ["ColQwen3", "ColQwen3PreTrainedModel", "ColQwen3ForRetrievalOutput"]
preprocessor_config.json ADDED
@@ -0,0 +1,42 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_colqwen3.ColQwen3Processor"
4
+ },
5
+ "crop_size": null,
6
+ "data_format": "channels_first",
7
+ "default_to_square": true,
8
+ "device": null,
9
+ "disable_grouping": null,
10
+ "do_center_crop": null,
11
+ "do_convert_rgb": true,
12
+ "do_normalize": true,
13
+ "do_pad": null,
14
+ "do_rescale": true,
15
+ "do_resize": true,
16
+ "image_mean": [
17
+ 0.5,
18
+ 0.5,
19
+ 0.5
20
+ ],
21
+ "image_processor_type": "Qwen2VLImageProcessorFast",
22
+ "image_std": [
23
+ 0.5,
24
+ 0.5,
25
+ 0.5
26
+ ],
27
+ "input_data_format": null,
28
+ "max_pixels": 1310720,
29
+ "merge_size": 2,
30
+ "min_pixels": null,
31
+ "pad_size": null,
32
+ "patch_size": 16,
33
+ "processor_class": "ColQwen3Processor",
34
+ "resample": 3,
35
+ "rescale_factor": 0.00392156862745098,
36
+ "return_tensors": null,
37
+ "size": {
38
+ "longest_edge": 1310720,
39
+ "shortest_edge": 65536
40
+ },
41
+ "temporal_patch_size": 2
42
+ }
processing_colqwen3.py ADDED
@@ -0,0 +1,723 @@
1
+ """
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ Processing utilities for ColQwen3, aligned with the ColQwen2 reference implementation.
17
+ """
18
+
19
+ import importlib
20
+ import numpy as np
21
+ from typing import Any, ClassVar, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ from PIL import Image
25
+ from transformers import BatchEncoding
26
+ from transformers.feature_extraction_utils import BatchFeature
27
+ from transformers.image_utils import ImageInput, is_valid_image
28
+ from transformers.processing_utils import AudioInput, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideoInput
29
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
30
+ from transformers.utils import logging
31
+
32
+ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+ try:
37
+ from fast_plaid import search
38
+ except ImportError:
39
+ logger.info(
40
+ "FastPlaid is not installed.If you want to use it:Instal with `pip install --no-deps fast-plaid fastkmeans`"
41
+ )
42
+
43
+
44
+ def get_torch_device(device: str = "auto") -> str:
45
+ """Resolve a torch device string with a simple auto mode."""
46
+ if device == "auto":
47
+ if torch.cuda.is_available():
48
+ device = "cuda:0"
49
+ elif torch.backends.mps.is_available(): # for Apple Silicon
50
+ device = "mps"
51
+ else:
52
+ device = "cpu"
53
+ return device
54
+
55
+
56
+ class ColQwen3ProcessorKwargs(ProcessingKwargs, total=False):
57
+ _defaults = {
58
+ "text_kwargs": {
59
+ "padding": "longest",
60
+ },
61
+ "images_kwargs": {
62
+ "data_format": "channels_first",
63
+ "do_convert_rgb": True,
64
+ },
65
+ "videos_kwargs": {
66
+ "return_metadata": True,
67
+ "data_format": "channels_first",
68
+ "do_convert_rgb": True,
69
+ },
70
+ "common_kwargs": {"return_tensors": "pt"},
71
+ }
72
+
73
+
74
+ class ColQwen3Processor(ProcessorMixin):
75
+ """
76
+ Constructs a ColQwen3 processor which wraps a Qwen3VLProcessor with retrieval-specific helpers.
77
+ """
78
+
79
+ attributes = ["image_processor", "tokenizer", "video_processor"]
80
+ image_processor_class = "AutoImageProcessor"
81
+ video_processor_class = "AutoVideoProcessor"
82
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
83
+
84
+ def __init__(
85
+ self,
86
+ image_processor=None,
87
+ tokenizer=None,
88
+ video_processor=None,
89
+ chat_template=None,
90
+ visual_prompt_prefix: Optional[str] = None,
91
+ visual_prompt_suffix: Optional[str] = None,
92
+ video_prompt_prefix: Optional[str] = None,
93
+ video_prompt_suffix: Optional[str] = None,
94
+ query_prefix: Optional[str] = None,
95
+ **kwargs,
96
+ ):
97
+ super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template, **kwargs)
98
+ self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
99
+ self.image_token_id = (
100
+ tokenizer.image_token_id
101
+ if getattr(tokenizer, "image_token_id", None)
102
+ else tokenizer.convert_tokens_to_ids(self.image_token)
103
+ )
104
+ self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
105
+ self.video_token_id = (
106
+ tokenizer.video_token_id
107
+ if getattr(tokenizer, "video_token_id", None)
108
+ else tokenizer.convert_tokens_to_ids(self.video_token)
109
+ )
110
+ self.vision_start_token = (
111
+ "<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token
112
+ )
113
+ self.vision_end_token = (
114
+ "<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token
115
+ )
116
+ self.vision_start_token_id = (
117
+ tokenizer.vision_start_token_id
118
+ if getattr(tokenizer, "vision_start_token_id", None)
119
+ else tokenizer.convert_tokens_to_ids(self.vision_start_token)
120
+ )
121
+ self.vision_end_token_id = (
122
+ tokenizer.vision_end_token_id
123
+ if getattr(tokenizer, "vision_end_token_id", None)
124
+ else tokenizer.convert_tokens_to_ids(self.vision_end_token)
125
+ )
126
+
127
+ if visual_prompt_prefix is None:
128
+ visual_prompt_prefix = (
129
+ "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image."
130
+ )
131
+ self.visual_prompt_prefix = visual_prompt_prefix
132
+ if visual_prompt_suffix is None:
133
+ visual_prompt_suffix = "<|im_end|><|endoftext|>"
134
+ self.visual_prompt_suffix = visual_prompt_suffix
135
+
136
+ if video_prompt_prefix is None:
137
+ video_prompt_prefix = (
138
+ "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>Describe the video."
139
+ )
140
+ self.video_prompt_prefix = video_prompt_prefix
141
+ if video_prompt_suffix is None:
142
+ video_prompt_suffix = "<|im_end|><|endoftext|>"
143
+ self.video_prompt_suffix = video_prompt_suffix
144
+
145
+ if query_prefix is None:
146
+ query_prefix = ""
147
+ self.query_prefix = query_prefix
148
+ self.tokenizer.padding_side = "left"
149
+
150
+ @classmethod
151
+ def from_pretrained( # type: ignore[override]
152
+ cls,
153
+ *args: Any,
154
+ max_num_visual_tokens: int = 1280,
155
+ **kwargs: Any,
156
+ ) -> "ColQwen3Processor":
157
+ instance = super().from_pretrained(
158
+ *args,
159
+ **kwargs,
160
+ )
161
+
162
+ patch_size = getattr(instance.image_processor, "patch_size", None)
163
+ merge_size = getattr(instance.image_processor, "merge_size", None) or getattr(
164
+ instance.image_processor, "spatial_merge_size", None
165
+ )
166
+ if patch_size is None or merge_size is None:
167
+ raise ValueError("Qwen3VL image processor is missing `patch_size` or `merge_size`/`spatial_merge_size`.")
168
+ tile = patch_size * merge_size
169
+ instance.image_processor.max_pixels = max_num_visual_tokens * tile * tile
170
+ instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels
171
+
172
+ video_patch_size = getattr(instance.video_processor, "patch_size", None)
173
+ video_merge_size = getattr(instance.video_processor, "merge_size", None) or getattr(
174
+ instance.video_processor, "spatial_merge_size", None
175
+ )
176
+ video_temporal_patch_size = getattr(instance.video_processor, "temporal_patch_size", None)
177
+ if video_patch_size is None or video_merge_size is None or video_temporal_patch_size is None:
178
+ raise ValueError(
179
+ "Qwen3VL video processor is missing `patch_size`, `merge_size`/`spatial_merge_size`, or `temporal_patch_size`."
180
+ )
181
+ video_tile = video_patch_size * video_merge_size
182
+ # Include temporal patching so the visual token cap applies across space and time.
183
+ instance.video_processor.max_pixels = max_num_visual_tokens * video_tile * video_tile * video_temporal_patch_size
184
+ instance.video_processor.size["longest_edge"] = instance.video_processor.max_pixels
185
+
186
+ return instance
187
+
188
+ def __call__(
189
+ self,
190
+ images: Optional[ImageInput] = None,
191
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
192
+ audio: Optional[AudioInput] = None,
193
+ videos: Optional[VideoInput] = None,
194
+ **kwargs: Unpack[ColQwen3ProcessorKwargs],
195
+ ) -> BatchFeature:
196
+ output_kwargs = self._merge_kwargs(
197
+ ColQwen3ProcessorKwargs,
198
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
199
+ **kwargs,
200
+ )
201
+ suffix = output_kwargs["text_kwargs"].pop("suffix", None)
202
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
203
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
204
+
205
+ if images is not None and videos is not None:
206
+ raise ValueError("Provide only one of `images` or `videos`, not both.")
207
+
208
+ # Normalize text inputs
209
+ text_list: list[str] = []
210
+ if text is not None:
211
+ if isinstance(text, str):
212
+ text_list = [text]
213
+ elif isinstance(text, list):
214
+ if len(text) == 0 or not all(isinstance(t, (str, type(None))) for t in text):
215
+ raise ValueError("Text must be a string or a list of strings.")
216
+ text_list = [t or "" for t in text]
217
+ else:
218
+ raise ValueError("Text must be a string or a list of strings")
219
+
220
+ # Normalize image inputs
221
+ image_list: Optional[list[Any]] = None
222
+ if images is not None:
223
+ raw_images = images if isinstance(images, list) else [images]
224
+ image_list = []
225
+ for idx, img_item in enumerate(raw_images):
226
+ if img_item is None:
227
+ image_list.append([])
228
+ elif is_valid_image(img_item):
229
+ image_list.append([img_item])
230
+ elif isinstance(img_item, list):
231
+ if not img_item:
232
+ image_list.append([])
233
+ continue
234
+ for sub_idx, sub_img in enumerate(img_item):
235
+ if not is_valid_image(sub_img):
236
+ raise ValueError(f"Image at position {idx}[{sub_idx}] is not a valid image.")
237
+ image_list.append(list(img_item))
238
+ else:
239
+ raise ValueError("images must be an image, list of images or list of list of images")
240
+
241
+ # Normalize video inputs
242
+ video_list: Optional[list[Any]] = None
243
+ if videos is not None:
244
+ raw_videos = list(videos) if isinstance(videos, (list, tuple)) else [videos]
245
+ video_list = []
246
+ for idx, vid_item in enumerate(raw_videos):
247
+ if vid_item is None:
248
+ video_list.append([])
249
+ elif isinstance(vid_item, list):
250
+ video_list.append(list(vid_item))
251
+ else:
252
+ video_list.append([vid_item])
253
+
254
+ if image_list is None and video_list is None and not text_list:
255
+ raise ValueError("Either text, images or videos must be provided")
256
+
257
+ # Align text length with provided vision inputs when needed
258
+ if image_list is not None:
259
+ if not text_list:
260
+ text_list = [""] * len(image_list)
261
+ elif len(text_list) == 1 and len(image_list) > 1:
262
+ text_list = text_list * len(image_list)
263
+ elif len(text_list) != len(image_list):
264
+ raise ValueError("When providing both images and text, their lengths must match.")
265
+ num_items = len(image_list)
266
+ elif video_list is not None:
267
+ if not text_list:
268
+ text_list = [""] * len(video_list)
269
+ elif len(text_list) == 1 and len(video_list) > 1:
270
+ text_list = text_list * len(video_list)
271
+ elif len(text_list) != len(video_list):
272
+ raise ValueError("When providing both videos and text, their lengths must match.")
273
+ num_items = len(video_list)
274
+ else:
275
+ num_items = len(text_list)
276
+
277
+ if num_items == 0:
278
+ raise ValueError("Either text, images or videos must be provided")
279
+
280
+ prompts: list[str] = []
281
+ query_suffix = suffix if suffix is not None else self.query_augmentation_token * 10
282
+
283
+ for idx in range(num_items):
284
+ extra_text = (text_list[idx] if idx < len(text_list) else "") or ""
285
+ extra_text = extra_text.strip()
286
+ has_image = image_list is not None and len(image_list[idx]) > 0
287
+ has_video = video_list is not None and len(video_list[idx]) > 0
288
+ if has_image and has_video:
289
+ raise ValueError("Provide only one of `images` or `videos` per item.")
290
+
291
+ if has_image:
292
+ prompt = (
293
+ f"{self.visual_prompt_prefix} {extra_text}{self.visual_prompt_suffix}"
294
+ if extra_text
295
+ else f"{self.visual_prompt_prefix}{self.visual_prompt_suffix}"
296
+ )
297
+ prompts.append(prompt)
298
+ elif has_video:
299
+ prompt = (
300
+ f"{self.video_prompt_prefix} {extra_text}{self.video_prompt_suffix}"
301
+ if extra_text
302
+ else f"{self.video_prompt_prefix}{self.video_prompt_suffix}"
303
+ )
304
+ prompts.append(prompt)
305
+ else:
306
+ prompt = self.query_prefix + extra_text + query_suffix
307
+ prompts.append(prompt)
308
+
309
+ # Process images (excluding empty placeholders)
310
+ image_inputs: dict[str, Any] = {}
311
+ image_grid_thw = None
312
+ if image_list is not None:
313
+ normalized_images: list[list[Image.Image]] = []
314
+ for idx, img_group in enumerate(image_list):
315
+ converted_list: list[Image.Image] = []
316
+ for sub_idx, sub_img in enumerate(img_group):
317
+ if not is_valid_image(sub_img):
318
+ raise ValueError(f"Image at position {idx}[{sub_idx}] is not a valid image.")
319
+ converted_list.append(sub_img.convert("RGB") if hasattr(sub_img, "convert") else sub_img)
320
+ normalized_images.append(converted_list)
321
+
322
+ image_inputs = self.image_processor(images=normalized_images, **output_kwargs["images_kwargs"])
323
+ image_grid_thw = image_inputs["image_grid_thw"]
324
+
325
+ # Process videos (excluding empty placeholders)
326
+ videos_inputs: dict[str, Any] = {}
327
+ video_grid_thw = None
328
+ video_metadata = None
329
+ if video_list is not None:
330
+ videos_inputs = self.video_processor(videos=video_list, **output_kwargs["videos_kwargs"])
331
+ video_grid_thw = videos_inputs["video_grid_thw"]
332
+ if "return_metadata" not in output_kwargs["videos_kwargs"]:
333
+ video_metadata = videos_inputs.pop("video_metadata")
334
+ else:
335
+ video_metadata = videos_inputs["video_metadata"]
336
+
337
+ # Expand prompts to match the number of visual tokens
338
+ text_prompts = prompts.copy()
339
+ if image_grid_thw is not None:
340
+ merge_size = getattr(self.image_processor, "merge_size", None) or getattr(
341
+ self.image_processor, "spatial_merge_size", None
342
+ )
343
+ if merge_size is None:
344
+ raise ValueError("Qwen3VL image processor is missing `merge_size`/`spatial_merge_size`.")
345
+ merge_length = merge_size**2
346
+ index = 0
347
+ for i in range(len(text_prompts)):
348
+ while self.image_token in text_prompts[i]:
349
+ if index >= len(image_grid_thw):
350
+ raise ValueError("Number of image tokens does not match provided images.")
351
+ num_image_tokens = image_grid_thw[index].prod() // merge_length
352
+ text_prompts[i] = text_prompts[i].replace(
353
+ self.image_token, "<|placeholder|>" * num_image_tokens, 1
354
+ )
355
+ index += 1
356
+ text_prompts[i] = text_prompts[i].replace("<|placeholder|>", self.image_token)
357
+
358
+ if video_grid_thw is not None:
359
+ merge_size = getattr(self.video_processor, "merge_size", None)
360
+ if merge_size is None:
361
+ raise ValueError("Qwen3VL video processor is missing `merge_size`.")
362
+ merge_length = merge_size**2
363
+ index = 0
364
+ for i in range(len(text_prompts)):
365
+ while self.video_token in text_prompts[i]:
366
+ if video_metadata is None or index >= len(video_metadata):
367
+ raise ValueError("Video metadata is required to build video prompts.")
368
+ metadata = video_metadata[index]
369
+ if metadata.fps is None:
370
+ logger.warning_once(
371
+ "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could "
372
+ "not be inferred. Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
373
+ )
374
+ metadata.fps = 24 if metadata.fps is None else metadata.fps
375
+
376
+ curr_timestamp = self._calculate_timestamps(
377
+ metadata.frames_indices, metadata.fps, self.video_processor.merge_size
378
+ )
379
+ frame_seqlen = int(video_grid_thw[index][1:].prod().item() // merge_length)
380
+ video_placeholder = ""
381
+ for frame_idx in range(int(video_grid_thw[index][0])):
382
+ curr_time = curr_timestamp[frame_idx]
383
+ video_placeholder += f"<{curr_time:.1f} seconds>"
384
+ video_placeholder += (
385
+ self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token
386
+ )
387
+
388
+ if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text_prompts[i]:
389
+ text_prompts[i] = text_prompts[i].replace(
390
+ f"{self.vision_start_token}{self.video_token}{self.vision_end_token}",
391
+ video_placeholder,
392
+ 1,
393
+ )
394
+ else:
395
+ text_prompts[i] = text_prompts[i].replace(self.video_token, video_placeholder, 1)
396
+ index += 1
397
+
398
+ text_prompts[i] = text_prompts[i].replace("<|placeholder|>", self.video_token)
399
+
400
+ text_inputs = self.tokenizer(text_prompts, **output_kwargs["text_kwargs"])
401
+ self._check_special_mm_tokens(text_prompts, text_inputs, modalities=["image", "video"])
402
+
403
+ if return_mm_token_type_ids:
404
+ array_ids = np.array(text_inputs["input_ids"])
405
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
406
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
407
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
408
+
409
+ return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
410
+
411
+ def process_images(
412
+ self,
413
+ images: List[Image.Image],
414
+ ) -> Union[BatchFeature, BatchEncoding]:
415
+ images = [image.convert("RGB") for image in images]
416
+ return self(images=images, padding="longest", return_tensors="pt")
417
+
418
+ def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
419
+ return self(text=texts, return_tensors="pt", padding="longest")
420
+
421
+
422
+ @staticmethod
423
+ def _split_batch_feature(batch_feature: BatchFeature) -> list[BatchFeature]:
424
+ # Split a batched BatchFeature into a list of per-item BatchFeatures.
425
+ length: Optional[int] = None
426
+ for value in batch_feature.values():
427
+ if hasattr(value, "__len__"):
428
+ try:
429
+ length = len(value)
430
+ except Exception:
431
+ continue
432
+ if length is not None:
433
+ break
434
+
435
+ if length is None:
436
+ return [batch_feature]
437
+
438
+ items: list[BatchFeature] = []
439
+ for idx in range(length):
440
+ data = {}
441
+ for key, value in batch_feature.items():
442
+ try:
443
+ data[key] = value[idx]
444
+ except Exception:
445
+ data[key] = value
446
+ items.append(BatchFeature(data=data))
447
+ return items
448
+
449
+ @staticmethod
450
+ def _merge_batch_features(features: list[BatchFeature]) -> BatchFeature:
451
+ if not features:
452
+ return BatchFeature()
453
+
454
+ all_keys = set()
455
+ for feat in features:
456
+ all_keys.update(feat.keys())
457
+
458
+ merged: dict[str, list[Any]] = {key: [] for key in all_keys}
459
+ for feat in features:
460
+ for key in all_keys:
461
+ merged[key].append(feat.get(key))
462
+
463
+ combined: dict[str, Any] = {}
464
+ for key, values in merged.items():
465
+ # Prefer stacking tensors so callers get batched tensors instead of lists
466
+ if all(isinstance(v, torch.Tensor) for v in values):
467
+ try:
468
+ combined[key] = torch.stack(values)
469
+ continue
470
+ except Exception:
471
+ # Fallback to list if shapes are incompatible for stacking
472
+ pass
473
+ combined[key] = values
474
+
475
+ return BatchFeature(data=combined)
476
+
477
+ def score_retrieval(
478
+ self,
479
+ qs: List[torch.Tensor],
480
+ ps: List[torch.Tensor],
481
+ score_batch_size: int = 128,
482
+ device: Optional[Union[str, torch.device]] = None,
483
+ **kwargs,
484
+ ) -> torch.Tensor:
485
+ return self.score_multi_vector(qs, ps, batch_size=score_batch_size, device=device, **kwargs)
486
+
487
+ @staticmethod
488
+ def score_single_vector(
489
+ qs: Union[torch.Tensor, List[torch.Tensor]],
490
+ ps: Union[torch.Tensor, List[torch.Tensor]],
491
+ device: Optional[Union[str, torch.device]] = None,
492
+ ) -> torch.Tensor:
493
+ """
494
+ Compute the dot product score for the given single-vector query and passage embeddings.
495
+ """
496
+ device = device or get_torch_device("auto")
497
+
498
+ if isinstance(qs, list) and isinstance(ps, list):
499
+ if len(qs) == 0:
500
+ raise ValueError("No queries provided")
501
+ if len(ps) == 0:
502
+ raise ValueError("No passages provided")
503
+
504
+ qs = torch.stack(qs).to(device)
505
+ ps = torch.stack(ps).to(device)
506
+ else:
507
+ qs = qs.to(device)
508
+ ps = ps.to(device)
509
+
510
+ scores = torch.einsum("bd,cd->bc", qs, ps)
511
+ assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
512
+
513
+ scores = scores.to(torch.float32)
514
+ return scores
515
+
516
+ @staticmethod
517
+ def score_multi_vector(
518
+ qs: Union[torch.Tensor, List[torch.Tensor]],
519
+ ps: Union[torch.Tensor, List[torch.Tensor]],
520
+ batch_size: int = 128,
521
+ device: Optional[Union[str, torch.device]] = None,
522
+ ) -> torch.Tensor:
523
+ """
524
+ Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
525
+ query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
526
+ image of a document page.
527
+
528
+ Because the embedding tensors are multi-vector and can thus have different shapes, they
529
+ should be fed as:
530
+ (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
531
+ (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
532
+ obtained by padding the list of tensors.
533
+
534
+ Args:
535
+ qs (`Union[torch.Tensor, List[torch.Tensor]]`): Query embeddings.
536
+ ps (`Union[torch.Tensor, List[torch.Tensor]]`): Passage embeddings.
537
+ batch_size (`int`, *optional*): Batch size for computing scores.
538
+ device (`Union[str, torch.device]`, *optional*): Device to use for computation. If not
539
+ provided, uses `get_torch_device("auto")`.
540
+
541
+ Returns:
542
+ `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
543
+ tensor is returned on the CPU.
544
+ """
545
+ device = device or get_torch_device("auto")
546
+
547
+ if len(qs) == 0:
548
+ raise ValueError("No queries provided")
549
+ if len(ps) == 0:
550
+ raise ValueError("No passages provided")
551
+
552
+ scores_list: List[torch.Tensor] = []
553
+
554
+ for i in range(0, len(qs), batch_size):
555
+ scores_batch = []
556
+ qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
557
+ device
558
+ )
559
+ for j in range(0, len(ps), batch_size):
560
+ ps_batch = torch.nn.utils.rnn.pad_sequence(
561
+ ps[j : j + batch_size], batch_first=True, padding_value=0
562
+ ).to(device)
563
+ scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
564
+ scores_batch = torch.cat(scores_batch, dim=1).cpu()
565
+ scores_list.append(scores_batch)
566
+
567
+ scores = torch.cat(scores_list, dim=0)
568
+ assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
569
+
570
+ scores = scores.to(torch.float32)
571
+ return scores
572
+
573
+ @staticmethod
574
+ def get_topk_plaid(
575
+ qs: Union[torch.Tensor, List[torch.Tensor]],
576
+ plaid_index: "search.FastPlaid",
577
+ k: int = 10,
578
+ batch_size: int = 128,
579
+ device: Optional[Union[str, torch.device]] = None,
580
+ ) -> List[torch.Tensor]:
581
+ """
582
+ Experimental: Retrieve the top-k passages by late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
583
+ query embeddings (`qs`) and passage embeddings encoded in a FastPlaid index. For ColPali, a passage is the
584
+ image of a document page.
585
+ """
586
+ device = device or get_torch_device("auto")
587
+
588
+ if len(qs) == 0:
589
+ raise ValueError("No queries provided")
590
+
591
+ scores_list: List[torch.Tensor] = []
592
+
593
+ for i in range(0, len(qs), batch_size):
594
+ scores_batch = []
595
+ qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
596
+ device
597
+ )
598
+ scores_batch = plaid_index.search(
599
+ queries_embeddings=qs_batch.to(torch.float32),
600
+ top_k=k,
601
+ )
602
+ scores_list.append(scores_batch)
603
+
604
+ return scores_list
605
+
606
+ @staticmethod
607
+ def create_plaid_index(
608
+ ps: Union[torch.Tensor, List[torch.Tensor]],
609
+ device: Optional[Union[str, torch.device]] = None,
610
+ ) -> "search.FastPlaid":
611
+ """
612
+ Experimental: Create a FastPlaid index from the given passage embeddings.
613
+ Args:
614
+ ps (`Union[torch.Tensor, List[torch.Tensor]]`): Passage embeddings. Should be a list of tensors,
615
+ where each tensor is of shape (sequence_length_i, embedding_dim).
616
+ device (`Optional[Union[str, torch.device]]`, *optional*): Device to use for computation. If not
617
+ provided, uses `get_torch_device("auto")`.
618
+ """
619
+ if not importlib.util.find_spec("fast_plaid"):
620
+ raise ImportError("FastPlaid is not installed. Please install it with `pip install fast-plaid`.")
621
+
622
+ fast_plaid_index = search.FastPlaid(index="index")
623
+ device = device or get_torch_device("auto")
624
+ fast_plaid_index.create(documents_embeddings=[d.to(device).to(torch.float32) for d in ps])
625
+ return fast_plaid_index
626
+
627
+ def get_n_patches(
628
+ self,
629
+ image_size: Tuple[int, int],
630
+ spatial_merge_size: int,
631
+ ) -> Tuple[int, int]:
632
+ """
633
+ Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of
634
+ size (width, height) with the given patch size.
635
+
636
+ The `spatial_merge_size` is the number of patches that will be merged spatially. It is stored
637
+ as a `Qwen2VLForConditionalGeneration` attribute under `model.spatial_merge_size`.
638
+ """
639
+ patch_size = self.image_processor.patch_size
640
+
641
+ height_new, width_new = smart_resize(
642
+ width=image_size[0],
643
+ height=image_size[1],
644
+ factor=patch_size * self.image_processor.merge_size,
645
+ min_pixels=self.image_processor.size["shortest_edge"],
646
+ max_pixels=self.image_processor.size["longest_edge"],
647
+ )
648
+
649
+ n_patches_x = width_new // patch_size // spatial_merge_size
650
+ n_patches_y = height_new // patch_size // spatial_merge_size
651
+
652
+ return n_patches_x, n_patches_y
653
+
654
+ def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
655
+ return batch_images.input_ids == self.image_token_id
656
+
657
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
658
+ vision_data = {}
659
+ if image_sizes is not None:
660
+ images_kwargs = ColQwen3ProcessorKwargs._defaults.get("images_kwargs", {})
661
+ images_kwargs.update(kwargs)
662
+ merge_size = images_kwargs.get("merge_size", None) or getattr(
663
+ self.image_processor, "merge_size", None
664
+ ) or getattr(self.image_processor, "spatial_merge_size", None)
665
+ if merge_size is None:
666
+ raise ValueError("Qwen3VL image processor is missing `merge_size`/`spatial_merge_size`.")
667
+
668
+ num_image_patches = [
669
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
670
+ for image_size in image_sizes
671
+ ]
672
+ num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
673
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
674
+
675
+ video_sizes = kwargs.pop("video_sizes", None)
676
+ if video_sizes is not None:
677
+ videos_kwargs = ColQwen3ProcessorKwargs._defaults.get("videos_kwargs", {})
678
+ videos_kwargs.update(kwargs)
679
+ merge_size = videos_kwargs.get("merge_size", None) or getattr(self.video_processor, "merge_size", None)
680
+ if merge_size is None:
681
+ raise ValueError("Qwen3VL video processor is missing `merge_size`.")
682
+
683
+ num_video_patches = [
684
+ self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs) for video_size in video_sizes
685
+ ]
686
+ num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
687
+ vision_data.update({"num_video_tokens": num_video_tokens, "num_video_patches": num_video_patches})
688
+
689
+ return MultiModalData(**vision_data)
690
+
691
+ @property
692
+ def model_input_names(self) -> list[str]:
693
+ return [
694
+ "input_ids",
695
+ "attention_mask",
696
+ "pixel_values",
697
+ "image_grid_thw",
698
+ "pixel_values_videos",
699
+ "video_grid_thw",
700
+ ]
701
+
702
+ @property
703
+ def query_augmentation_token(self) -> str:
704
+ return self.tokenizer.pad_token
705
+
706
+ def get_video_mask(self, batch_videos: BatchFeature) -> torch.Tensor:
707
+ return batch_videos.input_ids == self.video_token_id
708
+
709
+ def _calculate_timestamps(
710
+ self, indices: Union[list[int], np.ndarray], video_fps: float, merge_size: int = 2
711
+ ) -> list[float]:
712
+ if not isinstance(indices, list):
713
+ indices = indices.tolist()
714
+ if len(indices) % merge_size != 0:
715
+ indices.extend(indices[-1] for _ in range(merge_size - len(indices) % merge_size))
716
+ timestamps = [idx / video_fps for idx in indices]
717
+ timestamps = [
718
+ (timestamps[i] + timestamps[i + merge_size - 1]) / 2 for i in range(0, len(timestamps), merge_size)
719
+ ]
720
+ return timestamps
721
+
722
+
723
+ __all__ = ["ColQwen3Processor", "ColQwen3ProcessorKwargs"]
processor_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_colqwen3.ColQwen3Processor"
4
+ },
5
+ "processor_class": "ColQwen3Processor",
6
+ "query_prefix": "",
7
+ "video_prompt_prefix": "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>Describe the video.",
8
+ "video_prompt_suffix": "<|im_end|><|endoftext|>",
9
+ "visual_prompt_prefix": "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.",
10
+ "visual_prompt_suffix": "<|im_end|><|endoftext|>"
11
+ }
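
The prefix/suffix strings above are what the processor wraps around each page image; the single `<|image_pad|>` placeholder is later expanded into one token per merged visual patch. A rough illustration (the expansion itself happens inside the processor and depends on the page resolution):

```python
# Illustration only: how the prompt strings from processor_config.json fit together.
visual_prompt_prefix = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image."
visual_prompt_suffix = "<|im_end|><|endoftext|>"

page_prompt = visual_prompt_prefix + visual_prompt_suffix
# During tokenization, <|image_pad|> is repeated roughly n_patches_x * n_patches_y times,
# up to the model's visual-token budget, so a page becomes up to ~1280 visual tokens.
print(page_prompt)
```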
quantization_config.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "bits": 4,
3
+ "group_size": 128,
4
+ "sym": true,
5
+ "data_type": "int",
6
+ "seqlen": 1024,
7
+ "batch_size": 80,
8
+ "iters": 1000,
9
+ "nsamples": 560,
10
+ "autoround_version": "0.9.2",
11
+ "block_name_to_quantize": "vlm.model.language_model.layers",
12
+ "quant_method": "auto-round",
13
+ "packing_format": "auto_round:auto_gptq"
14
+ }
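
This config is what is read at load time (`quant_method: auto-round`, 4-bit symmetric weights in groups of 128, applied only to `vlm.model.language_model.layers`). As a toy illustration of the storage format it describes, symmetric group-wise 4-bit quantization looks roughly like this; real AutoRound additionally learns rounding and clipping parameters, so this is only the W4A16 idea, not the actual algorithm:

```python
# Toy sketch of symmetric 4-bit group quantization (group_size=128), W4A16 style:
# weights stored as int4 per group with one scale, dequantized to 16-bit at inference.
import torch

w = torch.randn(128)                                  # one weight group
scale = w.abs().max() / 7                             # symmetric int4 uses levels in [-8, 7]
q = torch.clamp(torch.round(w / scale), -8, 7)        # what gets packed into 4 bits
w_hat = (q * scale).to(torch.float16)                 # dequantized weights used with fp16 activations
print(f"max abs error: {(w - w_hat.float()).abs().max():.4f}")
```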
quantization_metadata.json ADDED
@@ -0,0 +1,17 @@
1
+ {
2
+ "quantization_method": "auto-round",
3
+ "scheme": "W4A16",
4
+ "bits": 4,
5
+ "group_size": 128,
6
+ "sym": true,
7
+ "data_type": "int",
8
+ "act_bits": 16,
9
+ "iters": 1000,
10
+ "nsamples": 500,
11
+ "calibration_dataset": "NeelNanda/pile-10k (AutoRound default)",
12
+ "calibration_type": "text-only (language model only)",
13
+ "quantized_layers": 252,
14
+ "fp16_layers": 105,
15
+ "original_model": "TomoroAI/tomoro-colqwen3-embed-4b",
16
+ "note": "Vision encoder kept in FP16 (not quantized). Text-only calibration is appropriate since only language_model is quantized."
17
+ }
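
The metadata above records the AutoRound recipe (text-only calibration on pile-10k, 1000 iterations, language model only). A hedged sketch of how such a checkpoint is typically produced with the `auto-round` library; the exact loading calls for this base model, and any multimodal-specific flags it may need, are assumptions rather than the authors' script:

```python
# Sketch of the implied AutoRound W4A16 recipe (not the authors' exact pipeline).
from transformers import AutoModel, AutoTokenizer
from auto_round import AutoRound

base = "TomoroAI/tomoro-colqwen3-embed-4b"
model = AutoModel.from_pretrained(base, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(base, trust_remote_code=True)

autoround = AutoRound(
    model,
    tokenizer,
    bits=4,
    group_size=128,
    sym=True,
    seqlen=1024,     # calibration sequence length
    iters=1000,
    nsamples=560,
    batch_size=80,
)
autoround.quantize()
autoround.save_quantized("tomoro-colqwen3-embed-4b-awq", format="auto_round")
```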
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
3
+ size 11422654
tokenizer_config.json ADDED
@@ -0,0 +1,243 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "auto_map": {
230
+ "AutoProcessor": "processing_colqwen3.ColQwen3Processor"
231
+ },
232
+ "bos_token": null,
233
+ "clean_up_tokenization_spaces": false,
234
+ "eos_token": "<|im_end|>",
235
+ "errors": "replace",
236
+ "extra_special_tokens": {},
237
+ "model_max_length": 262144,
238
+ "pad_token": "<|endoftext|>",
239
+ "processor_class": "ColQwen3Processor",
240
+ "split_special_tokens": false,
241
+ "tokenizer_class": "Qwen2Tokenizer",
242
+ "unk_token": null
243
+ }
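
One detail worth noting from this config: the pad token is `<|endoftext|>`, and the processor's `query_augmentation_token` property (see `processing_colqwen3.py` above) returns exactly that token. ColPali-style retrievers append a handful of such tokens to queries as learned "expansion" slots; the count is not specified in this file, so the number below is a placeholder:

```python
# Hypothetical illustration of ColPali-style query augmentation with the pad token.
query = "total revenue in fiscal year 2023"
n_aug = 10  # placeholder: the actual augmentation length is set by the processor, not this config
augmented_query = query + "<|endoftext|>" * n_aug
```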
video_preprocessor_config.json ADDED
@@ -0,0 +1,45 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_colqwen3.ColQwen3Processor"
4
+ },
5
+ "crop_size": null,
6
+ "data_format": "channels_first",
7
+ "default_to_square": true,
8
+ "device": null,
9
+ "do_center_crop": null,
10
+ "do_convert_rgb": true,
11
+ "do_normalize": true,
12
+ "do_rescale": true,
13
+ "do_resize": true,
14
+ "do_sample_frames": true,
15
+ "fps": 2,
16
+ "image_mean": [
17
+ 0.5,
18
+ 0.5,
19
+ 0.5
20
+ ],
21
+ "image_std": [
22
+ 0.5,
23
+ 0.5,
24
+ 0.5
25
+ ],
26
+ "input_data_format": null,
27
+ "max_frames": 768,
28
+ "max_pixels": 2621440,
29
+ "merge_size": 2,
30
+ "min_frames": 4,
31
+ "num_frames": null,
32
+ "pad_size": null,
33
+ "patch_size": 16,
34
+ "processor_class": "ColQwen3Processor",
35
+ "resample": 3,
36
+ "rescale_factor": 0.00392156862745098,
37
+ "return_metadata": false,
38
+ "size": {
39
+ "longest_edge": 2621440,
40
+ "shortest_edge": 4096
41
+ },
42
+ "temporal_patch_size": 2,
43
+ "video_metadata": null,
44
+ "video_processor_type": "Qwen3VLVideoProcessor"
45
+ }
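
As a back-of-envelope check on these settings (2 fps sampling, 16-pixel patches, 2x2 spatial merge, temporal patch size 2), the video token count can be approximated as below; the real processor also applies smart resizing and the min/max pixel and frame clamps, so treat this as an estimate only:

```python
# Rough video-token estimate under video_preprocessor_config.json
# (ignores smart_resize rounding and per-frame pixel clamping).
def approx_video_tokens(duration_s: float, height: int, width: int,
                        fps: int = 2, patch_size: int = 16, merge_size: int = 2,
                        temporal_patch_size: int = 2, max_frames: int = 768) -> int:
    frames = min(int(duration_s * fps), max_frames)
    grid_t = max(frames // temporal_patch_size, 1)
    grid_h = height // patch_size
    grid_w = width // patch_size
    return grid_t * (grid_h // merge_size) * (grid_w // merge_size)

print(approx_video_tokens(30, 720, 1280))  # ~26k tokens for a 30 s 720p clip before any clamping
```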
vocab.json ADDED
The diff for this file is too large to render. See raw diff