---
library_name: diffusers
tags:
- fp8
- safetensors
- precision-recovery
- mixed-method
- converted-by-gradio
---
# FP8 Model with Per-Tensor Precision Recovery
- **Source**: `https://huggingface.co/stabilityai/sd-vae-ft-mse`
- **Original File**: `diffusion_pytorch_model.safetensors`
- **FP8 Format**: `E5M2`
- **FP8 File**: `diffusion_pytorch_model-fp8-e5m2.safetensors`
- **Recovery File**: `diffusion_pytorch_model-recovery.safetensors`
## Recovery Rules Used
```json
[
{
"key_pattern": "vae",
"dim": 4,
"method": "diff"
},
{
"key_pattern": "encoder",
"dim": 4,
"method": "diff"
},
{
"key_pattern": "decoder",
"dim": 4,
"method": "diff"
},
{
"key_pattern": "text",
"dim": 2,
"min_size": 10000,
"method": "lora",
"rank": 64
},
{
"key_pattern": "emb",
"dim": 2,
"min_size": 10000,
"method": "lora",
"rank": 64
},
{
"key_pattern": "attn",
"dim": 2,
"min_size": 10000,
"method": "lora",
"rank": 128
},
{
"key_pattern": "conv",
"dim": 4,
"method": "diff"
},
{
"key_pattern": "resnet",
"dim": 4,
"method": "diff"
},
{
"key_pattern": "all",
"method": "none"
}
]
```
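The converter's exact matching semantics are not documented here, but the rule list reads naturally as first-match-wins: for each tensor, scan the rules in order and apply the first one whose conditions hold. A minimal sketch under that assumption, where `key_pattern` is a substring match on the tensor key, `dim` is the tensor's number of dimensions, `min_size` its element count, and `"all"` a catch-all (the function name `select_rule` is hypothetical, not from the converter):

```python
def select_rule(key: str, ndim: int, numel: int, rules: list[dict]) -> dict:
    """Return the first rule matching a tensor (assumed first-match-wins)."""
    for rule in rules:
        # "all" acts as a wildcard; other patterns are substring matches
        if rule["key_pattern"] != "all" and rule["key_pattern"] not in key:
            continue
        # Optional constraints on dimensionality and element count
        if "dim" in rule and ndim != rule["dim"]:
            continue
        if "min_size" in rule and numel < rule["min_size"]:
            continue
        return rule
    return {"method": "none"}  # fall-through: no recovery
```

Under this reading, 4-D convolution weights get full-difference recovery, large 2-D attention/embedding matrices get rank-limited LoRA recovery, and everything else (e.g. biases and norms) is left as plain FP8 by the final `"all"`/`"none"` rule.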
## Usage (Inference)
```python
import os

import torch
from safetensors.torch import load_file

# Load the FP8 model weights
fp8_state = load_file("diffusion_pytorch_model-fp8-e5m2.safetensors")

# Load the recovery weights if the file is present
recovery_path = "diffusion_pytorch_model-recovery.safetensors"
recovery_state = load_file(recovery_path) if os.path.exists(recovery_path) else {}

# Reconstruct high-precision weights
reconstructed = {}
for key, weight in fp8_state.items():
    weight = weight.to(torch.float32)  # convert to float32 for computation

    # Apply LoRA recovery if both low-rank factors are available
    lora_a_key = f"lora_A.{key}"
    lora_b_key = f"lora_B.{key}"
    if lora_a_key in recovery_state and lora_b_key in recovery_state:
        A = recovery_state[lora_a_key].to(torch.float32)
        B = recovery_state[lora_b_key].to(torch.float32)
        weight = weight + B @ A  # add back the low-rank approximation

    # Apply difference recovery if available
    diff_key = f"diff.{key}"
    if diff_key in recovery_state:
        weight = weight + recovery_state[diff_key].to(torch.float32)

    reconstructed[key] = weight

# Use the reconstructed weights in your model
model.load_state_dict(reconstructed)
```
> **Note**: For best results, use the same recovery configuration during inference as was used during extraction.
> Requires PyTorch ≥ 2.1 for FP8 support.
## Statistics
- **Total layers**: 248
- **Layers with recovery**: 66
  - LoRA recovery: 2
  - Difference recovery: 64