Depth Estimation
Transformers
Safetensors
tipsv2_dpt
feature-extraction
vision
surface-normals
semantic-segmentation
dense-prediction
custom_code
Instructions to use google/tipsv2-g14-dpt with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use google/tipsv2-g14-dpt with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("depth-estimation", model="google/tipsv2-g14-dpt", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("google/tipsv2-g14-dpt", trust_remote_code=True, dtype="auto")
```

- Notebooks
- Google Colab
- Kaggle
| # Copyright 2025 Google LLC | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """DPT (Dense Prediction Transformer) depth head in PyTorch. | |
| Ported from the Scenic/Flax implementation at: | |
| research/vision/scene_understanding/imsight/modules/dpt.py | |
| scenic/projects/dense_features/models/decoders.py | |
| Architecture: | |
| ReassembleBlocks → 4×Conv3x3 → 4×FeatureFusionBlock → project → DepthHead | |
| """ | |
| import io | |
| import os | |
| import urllib.request | |
| import zipfile | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| # ── Building blocks ───────────────────────────────────────────────────────── | |
class PreActResidualConvUnit(nn.Module):
    """Residual unit that applies ReLU *before* each of its two 3x3 convs.

    Both convolutions keep the channel count and, with ``padding=1``, the
    spatial size; neither has a bias term.

    Args:
      features: number of input and output channels.
    """

    def __init__(self, features: int):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(features, features, **conv_kwargs)
        self.conv2 = nn.Conv2d(features, features, **conv_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``conv2(relu(conv1(relu(x)))) + x``."""
        # ReLU is applied out-of-place so the skip path sees the raw input.
        branch = self.conv1(F.relu(x))
        branch = self.conv2(F.relu(branch))
        return branch + x
class FeatureFusionBlock(nn.Module):
    """Merge an optional skip feature into the main path, then upsample 2x.

    Args:
      features: channel count of the incoming feature maps.
      has_residual: when True, an extra residual conv unit is created for
        the skip input.
      expand: when True, the closing 1x1 conv halves the channel count;
        otherwise it keeps it.
    """

    def __init__(self, features: int, has_residual: bool = False,
                 expand: bool = False):
        super().__init__()
        self.has_residual = has_residual
        if has_residual:
            self.residual_unit = PreActResidualConvUnit(features)
        self.main_unit = PreActResidualConvUnit(features)
        self.out_conv = nn.Conv2d(
            features, features // 2 if expand else features, 1, bias=True)

    def forward(self, x: torch.Tensor,
                residual: torch.Tensor = None) -> torch.Tensor:
        """Fuse ``residual`` (if used) into ``x`` and return the 2x upsampled map.

        The skip input is only consumed when the block was built with
        ``has_residual=True`` and ``residual`` is not None.
        """
        if self.has_residual and residual is not None:
            skip = residual
            if skip.shape != x.shape:
                # Bring the skip feature onto the main path's spatial grid.
                skip = F.interpolate(
                    skip, size=x.shape[2:], mode="bilinear",
                    align_corners=False)
            x = x + self.residual_unit(skip)
        fused = self.main_unit(x)
        # Upsample 2x with align_corners=True (matches Scenic reference).
        fused = F.interpolate(
            fused, scale_factor=2, mode="bilinear", align_corners=True)
        return self.out_conv(fused)
class ReassembleBlocks(nn.Module):
    """Turn per-level ViT features into maps at four different scales.

    Each level is optionally mixed with its cls token (``readout_type ==
    "project"``), projected by a 1x1 conv to that level's channel count and
    then spatially resampled: 4x up, 2x up, unchanged, 2x down.

    Args:
      input_embed_dim: channel count of the incoming patch features.
      out_channels: channel count produced for each level.
      readout_type: "project" enables the cls-token readout projection; any
        other value skips it.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 out_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project"):
        super().__init__()
        self.readout_type = readout_type

        # One 1x1 conv per level: embedding dim -> level width.
        per_level_proj = [
            nn.Conv2d(input_embed_dim, width, 1) for width in out_channels
        ]
        self.out_projections = nn.ModuleList(per_level_proj)

        # Spatial resampling per level.
        upsample_4x = nn.ConvTranspose2d(
            out_channels[0], out_channels[0],
            kernel_size=4, stride=4, padding=0)
        upsample_2x = nn.ConvTranspose2d(
            out_channels[1], out_channels[1],
            kernel_size=2, stride=2, padding=0)
        keep_size = nn.Identity()
        downsample_2x = nn.Conv2d(
            out_channels[3], out_channels[3], 3, stride=2, padding=1)
        self.resize_layers = nn.ModuleList(
            [upsample_4x, upsample_2x, keep_size, downsample_2x])

        # Linear layers that fold the concatenated (patch, cls) pair back to
        # the embedding width; only created for the "project" readout.
        if readout_type == "project":
            self.readout_projects = nn.ModuleList([
                nn.Linear(2 * input_embed_dim, input_embed_dim)
                for _ in out_channels
            ])

    def forward(self, features):
        """Reassemble a list of (cls_token, patch_feats) pairs.

        Args:
          features: list of (cls_token [B,D], patch_feats [B,D,H,W])

        Returns:
          list with one resampled tensor per input level.
        """
        pyramid = []
        for level, (cls_token, patches) in enumerate(features):
            batch, dim, height, width = patches.shape
            if self.readout_type == "project":
                # (B, D, H, W) -> (B, H*W, D)
                tokens = patches.flatten(2).transpose(1, 2)
                # Repeat the cls token for every spatial position.
                cls_tiled = cls_token.unsqueeze(1).expand(
                    -1, tokens.shape[1], -1)
                mixed = self.readout_projects[level](
                    torch.cat([tokens, cls_tiled], dim=-1))
                mixed = F.gelu(mixed)
                # Back to the (B, D, H, W) layout.
                patches = mixed.transpose(1, 2).reshape(
                    batch, dim, height, width)
            projected = self.out_projections[level](patches)
            pyramid.append(self.resize_layers[level](projected))
        return pyramid
class DPTDepthHead(nn.Module):
    """Full DPT head + depth classification decoder.

    Takes 4 intermediate ViT features and produces a depth map. Depth is
    predicted as the expectation over ``num_depth_bins`` bin centers spaced
    linearly between ``min_depth`` and ``max_depth``.

    Args:
      input_embed_dim: channel count of the incoming ViT features.
      channels: width of the fused decoder features.
      post_process_channels: per-level widths produced by the reassemble stage.
      readout_type: forwarded to ReassembleBlocks.
      num_depth_bins: number of depth bins predicted per pixel.
      min_depth: smallest bin center; also added to every bin weight so the
        normalisation never divides by zero.
      max_depth: largest bin center.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project",
                 num_depth_bins: int = 256,
                 min_depth: float = 1e-3,
                 max_depth: float = 10.0):
        super().__init__()
        self.num_depth_bins = num_depth_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        # Reassemble: project + resize
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )
        # 3x3 convs to map each level to `channels`
        self.convs = nn.ModuleList([
            nn.Conv2d(ch, channels, 3, padding=1, bias=False)
            for ch in post_process_channels
        ])
        # Fusion blocks: first has no residual, rest have residual
        self.fusion_blocks = nn.ModuleList([
            FeatureFusionBlock(channels, has_residual=False),
            FeatureFusionBlock(channels, has_residual=True),
            FeatureFusionBlock(channels, has_residual=True),
            FeatureFusionBlock(channels, has_residual=True),
        ])
        # Final projection
        self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True)
        # Depth classification head (Dense layer)
        self.depth_head = nn.Linear(channels, num_depth_bins)

    def forward(self, intermediate_features, image_size=None):
        """Run DPT depth prediction.

        Args:
          intermediate_features: list of 4 (cls_token, patch_feats) tuples
          image_size: (H, W) to resize output to, or None

        Returns:
          depth map tensor (B, 1, H, W) in the dtype of the head's output.
        """
        # Reassemble
        x = self.reassemble(intermediate_features)
        # 3x3 conv per level
        x = [self.convs[i](feat) for i, feat in enumerate(x)]
        # Fuse bottom-up: start from deepest (x[-1])
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, 4):
            out = self.fusion_blocks[i](out, residual=x[-(i + 1)])
        # Project
        out = F.relu(self.project(out))
        # Depth classification: (B, C, H, W) -> (B, H, W, num_bins)
        logits = self.depth_head(out.permute(0, 2, 3, 1))

        # The bin expectation is evaluated in at least float32. Building the
        # bin centers in the default dtype while the logits follow the module
        # dtype (e.g. a half-precision or float64 head) makes the einsum
        # operands disagree; promoting both sides keeps them consistent and
        # avoids low-precision rounding of the normalisation.
        compute_dtype = torch.promote_types(logits.dtype, torch.float32)
        bin_centers = torch.linspace(
            self.min_depth, self.max_depth, self.num_depth_bins,
            device=logits.device, dtype=compute_dtype)
        weights = F.relu(logits.to(compute_dtype)) + self.min_depth
        weights = weights / weights.sum(dim=-1, keepdim=True)
        depth = torch.einsum("bhwn,n->bhw", weights, bin_centers)
        # (B, H, W) -> (B, 1, H, W), back in the head's own dtype.
        depth = depth.unsqueeze(1).to(logits.dtype)
        # Resize to original image size
        if image_size is not None:
            depth = F.interpolate(depth, size=image_size, mode="bilinear",
                                  align_corners=False)
        return depth
class DPTNormalsHead(nn.Module):
    """DPT decoder followed by a per-pixel surface-normal regressor.

    Consumes 4 intermediate ViT features and emits a 3-channel normal map.

    Args:
      input_embed_dim: channel count of the incoming ViT features.
      channels: width of the fused decoder features.
      post_process_channels: per-level widths produced by the reassemble stage.
      readout_type: forwarded to ReassembleBlocks.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project"):
        super().__init__()
        # Stage 1: per-level projection and resampling.
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )
        # Stage 2: bring every level to the common decoder width.
        level_convs = [
            nn.Conv2d(width, channels, 3, padding=1, bias=False)
            for width in post_process_channels
        ]
        self.convs = nn.ModuleList(level_convs)
        # Stage 3: four fusion blocks; only the first one has no skip input.
        self.fusion_blocks = nn.ModuleList([
            FeatureFusionBlock(channels, has_residual=(idx > 0))
            for idx in range(4)
        ])
        # Stage 4: final 3x3 projection and the dense normals regressor.
        self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True)
        self.normals_head = nn.Linear(channels, 3)

    def forward(self, intermediate_features, image_size=None):
        """Predict surface normals.

        Args:
          intermediate_features: list of 4 (cls_token, patch_feats) tuples
          image_size: (H, W) to resize output to, or None

        Returns:
          normal map tensor (B, 3, H, W)
        """
        pyramid = self.reassemble(intermediate_features)
        pyramid = [self.convs[idx](level) for idx, level in enumerate(pyramid)]

        # Walk from the deepest level (last entry) towards the shallowest.
        fused = self.fusion_blocks[0](pyramid[-1])
        for step in (1, 2, 3):
            fused = self.fusion_blocks[step](
                fused, residual=pyramid[-(step + 1)])
        fused = self.project(fused)

        # Channels-last for the dense layer: (B, H, W, C) -> (B, H, W, 3).
        normals = self.normals_head(fused.permute(0, 2, 3, 1))
        # Unit-normalise each pixel's vector, then go back to (B, 3, H, W).
        normals = F.normalize(normals, p=2, dim=-1).permute(0, 3, 1, 2)
        if image_size is None:
            return normals
        # NOTE(review): the bilinear resize runs after normalisation, so the
        # resized vectors are not re-normalised — confirm callers are fine
        # with that.
        return F.interpolate(normals, size=image_size, mode="bilinear",
                             align_corners=False)
class DPTSegmentationHead(nn.Module):
    """DPT decoder followed by a per-pixel class-logit layer.

    Consumes 4 intermediate ViT features and emits a segmentation logit map.

    Args:
      input_embed_dim: channel count of the incoming ViT features.
      channels: width of the fused decoder features.
      post_process_channels: per-level widths produced by the reassemble stage.
      readout_type: forwarded to ReassembleBlocks.
      num_classes: number of output channels of the segmentation layer.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project",
                 num_classes: int = 150):
        super().__init__()
        # Stage 1: per-level projection and resampling.
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )
        # Stage 2: bring every level to the common decoder width.
        level_convs = [
            nn.Conv2d(width, channels, 3, padding=1, bias=False)
            for width in post_process_channels
        ]
        self.convs = nn.ModuleList(level_convs)
        # Stage 3: four fusion blocks; only the first one has no skip input.
        self.fusion_blocks = nn.ModuleList([
            FeatureFusionBlock(channels, has_residual=(idx > 0))
            for idx in range(4)
        ])
        # Stage 4: final 3x3 projection and the dense classifier.
        self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True)
        self.segmentation_head = nn.Linear(channels, num_classes)

    def forward(self, intermediate_features, image_size=None):
        """Predict per-pixel class logits.

        Args:
          intermediate_features: list of 4 (cls_token, patch_feats) tuples
          image_size: (H, W) to resize output to, or None

        Returns:
          segmentation map tensor (B, num_classes, H, W)
        """
        pyramid = self.reassemble(intermediate_features)
        pyramid = [self.convs[idx](level) for idx, level in enumerate(pyramid)]

        # Walk from the deepest level (last entry) towards the shallowest.
        fused = self.fusion_blocks[0](pyramid[-1])
        for step in (1, 2, 3):
            fused = self.fusion_blocks[step](
                fused, residual=pyramid[-(step + 1)])
        fused = self.project(fused)

        # Channels-last for the dense layer, then back to channels-first.
        logits = self.segmentation_head(fused.permute(0, 2, 3, 1))
        logits = logits.permute(0, 3, 1, 2)
        if image_size is None:
            return logits
        return F.interpolate(logits, size=image_size, mode="bilinear",
                             align_corners=False)
| # ── Weight loading from Scenic/Flax checkpoint ───────────────────────────── | |
| def _load_npy_from_zip(zf, name): | |
| """Load a single .npy array from a zipfile.""" | |
| with zf.open(name) as f: | |
| return np.load(io.BytesIO(f.read())) | |
| def _conv_kernel_flax_to_torch(w): | |
| """Convert Flax conv kernel (H,W,Cin,Cout) → PyTorch (Cout,Cin,H,W).""" | |
| return torch.from_numpy(w.transpose(3, 2, 0, 1).copy()) | |
| def _conv_transpose_kernel_flax_to_torch(w): | |
| """Convert Flax ConvTranspose kernel (H,W,Cin,Cout) → PyTorch (Cin,Cout,H,W).""" | |
| return torch.from_numpy(w.transpose(2, 3, 0, 1).copy()) | |
| def _linear_kernel_flax_to_torch(w): | |
| """Convert Flax Dense kernel (in,out) → PyTorch Linear (out,in).""" | |
| return torch.from_numpy(w.T.copy()) | |
| def _bias(w): | |
| return torch.from_numpy(w.copy()) | |
def load_dpt_weights(model: DPTDepthHead, zip_path: str):
    """Load Scenic/Flax DPT weights from a zip/npz file into PyTorch model.

    Every array is read from the archive, converted to the PyTorch layout and
    collected into one state dict that is loaded with ``strict=True``. The
    readout projections are always loaded, so ``model`` must have been built
    with ``readout_type="project"``.

    Args:
      model: depth head that receives the weights.
      zip_path: path of the zip/npz archive holding the ``.npy`` members.

    Returns:
      ``model``, with the weights loaded in place.

    Raises:
      KeyError: if an expected member is missing from the archive.
      RuntimeError: if the assembled state dict does not match ``model``.
    """
    prefix = "decoder/dpt/"
    reassemble = f"{prefix}reassemble_blocks/"
    sd = {}

    # The context manager closes the archive even when a member is missing or
    # a conversion fails part-way through.
    with zipfile.ZipFile(zip_path, "r") as zf:

        def npy(name):
            return _load_npy_from_zip(zf, name)

        def put(key, scope, convert, with_bias=True):
            # Store `<scope>/kernel.npy` (and optionally `<scope>/bias.npy`)
            # under the PyTorch parameter names `<key>.weight` / `<key>.bias`.
            sd[f"{key}.weight"] = convert(npy(f"{scope}/kernel.npy"))
            if with_bias:
                sd[f"{key}.bias"] = _bias(npy(f"{scope}/bias.npy"))

        # --- ReassembleBlocks ---
        for i in range(4):
            # out_projections (Conv2d 1x1)
            put(f"reassemble.out_projections.{i}",
                f"{reassemble}out_projection_{i}",
                _conv_kernel_flax_to_torch)
            # readout_projects (Linear)
            put(f"reassemble.readout_projects.{i}",
                f"{reassemble}readout_projects_{i}",
                _linear_kernel_flax_to_torch)
        # resize_layers: 0=ConvTranspose, 1=ConvTranspose, 2=Identity (no
        # weights), 3=Conv
        resize_converters = {
            0: _conv_transpose_kernel_flax_to_torch,
            1: _conv_transpose_kernel_flax_to_torch,
            3: _conv_kernel_flax_to_torch,
        }
        for i, convert in resize_converters.items():
            put(f"reassemble.resize_layers.{i}",
                f"{reassemble}resize_layers_{i}", convert)

        # --- Convs (3x3, no bias) ---
        for i in range(4):
            put(f"convs.{i}", f"{prefix}convs_{i}",
                _conv_kernel_flax_to_torch, with_bias=False)

        # --- Fusion blocks ---
        for i in range(4):
            fb = f"{prefix}fusion_blocks_{i}"
            # Block 0 has only the main unit (Flax index 0); later blocks have
            # the residual unit (index 0) followed by the main unit (index 1).
            units = ("main_unit",) if i == 0 else ("residual_unit", "main_unit")
            for j, unit in enumerate(units):
                for conv in ("conv1", "conv2"):
                    put(f"fusion_blocks.{i}.{unit}.{conv}",
                        f"{fb}/PreActResidualConvUnit_{j}/{conv}",
                        _conv_kernel_flax_to_torch, with_bias=False)
            # out_conv (Conv2d 1x1)
            put(f"fusion_blocks.{i}.out_conv", f"{fb}/Conv_0",
                _conv_kernel_flax_to_torch)

        # --- Project ---
        put("project", f"{prefix}project", _conv_kernel_flax_to_torch)

        # --- Depth classification head ---
        put("depth_head", "decoder/pixel_depth_classif",
            _linear_kernel_flax_to_torch)

    # strict=True raises on any missing or unexpected key, so a successful
    # return already guarantees a complete match.
    model.load_state_dict(sd, strict=True)
    print(f"Loaded DPT depth head weights ({len(sd)} tensors)")
    return model
def load_normals_weights(model: DPTNormalsHead, zip_path: str):
    """Load Scenic/Flax DPT weights from a zip/npz file into PyTorch model.

    Every array is read from the archive, converted to the PyTorch layout and
    collected into one state dict that is loaded with ``strict=True``. The
    readout projections are always loaded, so ``model`` must have been built
    with ``readout_type="project"``.

    Args:
      model: normals head that receives the weights.
      zip_path: path of the zip/npz archive holding the ``.npy`` members.

    Returns:
      ``model``, with the weights loaded in place.

    Raises:
      KeyError: if an expected member is missing from the archive.
      RuntimeError: if the assembled state dict does not match ``model``.
    """
    prefix = "decoder/dpt/"
    reassemble = f"{prefix}reassemble_blocks/"
    sd = {}

    # The context manager closes the archive even when a member is missing or
    # a conversion fails part-way through.
    with zipfile.ZipFile(zip_path, "r") as zf:

        def npy(name):
            return _load_npy_from_zip(zf, name)

        def put(key, scope, convert, with_bias=True):
            # Store `<scope>/kernel.npy` (and optionally `<scope>/bias.npy`)
            # under the PyTorch parameter names `<key>.weight` / `<key>.bias`.
            sd[f"{key}.weight"] = convert(npy(f"{scope}/kernel.npy"))
            if with_bias:
                sd[f"{key}.bias"] = _bias(npy(f"{scope}/bias.npy"))

        # --- ReassembleBlocks ---
        for i in range(4):
            # out_projections (Conv2d 1x1)
            put(f"reassemble.out_projections.{i}",
                f"{reassemble}out_projection_{i}",
                _conv_kernel_flax_to_torch)
            # readout_projects (Linear)
            put(f"reassemble.readout_projects.{i}",
                f"{reassemble}readout_projects_{i}",
                _linear_kernel_flax_to_torch)
        # resize_layers: 0=ConvTranspose, 1=ConvTranspose, 2=Identity (no
        # weights), 3=Conv
        resize_converters = {
            0: _conv_transpose_kernel_flax_to_torch,
            1: _conv_transpose_kernel_flax_to_torch,
            3: _conv_kernel_flax_to_torch,
        }
        for i, convert in resize_converters.items():
            put(f"reassemble.resize_layers.{i}",
                f"{reassemble}resize_layers_{i}", convert)

        # --- Convs (3x3, no bias) ---
        for i in range(4):
            put(f"convs.{i}", f"{prefix}convs_{i}",
                _conv_kernel_flax_to_torch, with_bias=False)

        # --- Fusion blocks ---
        for i in range(4):
            fb = f"{prefix}fusion_blocks_{i}"
            # Block 0 has only the main unit (Flax index 0); later blocks have
            # the residual unit (index 0) followed by the main unit (index 1).
            units = ("main_unit",) if i == 0 else ("residual_unit", "main_unit")
            for j, unit in enumerate(units):
                for conv in ("conv1", "conv2"):
                    put(f"fusion_blocks.{i}.{unit}.{conv}",
                        f"{fb}/PreActResidualConvUnit_{j}/{conv}",
                        _conv_kernel_flax_to_torch, with_bias=False)
            # out_conv (Conv2d 1x1)
            put(f"fusion_blocks.{i}.out_conv", f"{fb}/Conv_0",
                _conv_kernel_flax_to_torch)

        # --- Project ---
        put("project", f"{prefix}project", _conv_kernel_flax_to_torch)

        # --- Normals head ---
        put("normals_head", "decoder/pixel_normals",
            _linear_kernel_flax_to_torch)

    # strict=True raises on any missing or unexpected key, so a successful
    # return already guarantees a complete match.
    model.load_state_dict(sd, strict=True)
    print(f"Loaded DPT normals head weights ({len(sd)} tensors)")
    return model
def load_segmentation_weights(model: DPTSegmentationHead, zip_path: str):
    """Load Scenic/Flax DPT weights from a zip/npz file into PyTorch model.

    Every array is read from the archive, converted to the PyTorch layout and
    collected into one state dict that is loaded with ``strict=True``. The
    readout projections are always loaded, so ``model`` must have been built
    with ``readout_type="project"``.

    Args:
      model: segmentation head that receives the weights.
      zip_path: path of the zip/npz archive holding the ``.npy`` members.

    Returns:
      ``model``, with the weights loaded in place.

    Raises:
      KeyError: if an expected member is missing from the archive.
      RuntimeError: if the assembled state dict does not match ``model``.
    """
    prefix = "decoder/dpt/"
    reassemble = f"{prefix}reassemble_blocks/"
    sd = {}

    # The context manager closes the archive even when a member is missing or
    # a conversion fails part-way through.
    with zipfile.ZipFile(zip_path, "r") as zf:

        def npy(name):
            return _load_npy_from_zip(zf, name)

        def put(key, scope, convert, with_bias=True):
            # Store `<scope>/kernel.npy` (and optionally `<scope>/bias.npy`)
            # under the PyTorch parameter names `<key>.weight` / `<key>.bias`.
            sd[f"{key}.weight"] = convert(npy(f"{scope}/kernel.npy"))
            if with_bias:
                sd[f"{key}.bias"] = _bias(npy(f"{scope}/bias.npy"))

        # --- ReassembleBlocks ---
        for i in range(4):
            # out_projections (Conv2d 1x1)
            put(f"reassemble.out_projections.{i}",
                f"{reassemble}out_projection_{i}",
                _conv_kernel_flax_to_torch)
            # readout_projects (Linear)
            put(f"reassemble.readout_projects.{i}",
                f"{reassemble}readout_projects_{i}",
                _linear_kernel_flax_to_torch)
        # resize_layers: 0=ConvTranspose, 1=ConvTranspose, 2=Identity (no
        # weights), 3=Conv
        resize_converters = {
            0: _conv_transpose_kernel_flax_to_torch,
            1: _conv_transpose_kernel_flax_to_torch,
            3: _conv_kernel_flax_to_torch,
        }
        for i, convert in resize_converters.items():
            put(f"reassemble.resize_layers.{i}",
                f"{reassemble}resize_layers_{i}", convert)

        # --- Convs (3x3, no bias) ---
        for i in range(4):
            put(f"convs.{i}", f"{prefix}convs_{i}",
                _conv_kernel_flax_to_torch, with_bias=False)

        # --- Fusion blocks ---
        for i in range(4):
            fb = f"{prefix}fusion_blocks_{i}"
            # Block 0 has only the main unit (Flax index 0); later blocks have
            # the residual unit (index 0) followed by the main unit (index 1).
            units = ("main_unit",) if i == 0 else ("residual_unit", "main_unit")
            for j, unit in enumerate(units):
                for conv in ("conv1", "conv2"):
                    put(f"fusion_blocks.{i}.{unit}.{conv}",
                        f"{fb}/PreActResidualConvUnit_{j}/{conv}",
                        _conv_kernel_flax_to_torch, with_bias=False)
            # out_conv (Conv2d 1x1)
            put(f"fusion_blocks.{i}.out_conv", f"{fb}/Conv_0",
                _conv_kernel_flax_to_torch)

        # --- Project ---
        put("project", f"{prefix}project", _conv_kernel_flax_to_torch)

        # --- Segmentation head ---
        put("segmentation_head", "decoder/pixel_segmentation",
            _linear_kernel_flax_to_torch)

    # strict=True raises on any missing or unexpected key, so a successful
    # return already guarantees a complete match.
    model.load_state_dict(sd, strict=True)
    print(f"Loaded DPT segmentation head weights ({len(sd)} tensors)")
    return model