""" |
|
Custom CLIP Model with Register Tokens - Import Safe Version with Complete File Download |
|
""" |
|
import torch |
|
import torch.nn as nn |
|
from transformers import PreTrainedModel, PretrainedConfig |
|
from transformers.utils import logging |
|
from typing import Optional, Union, Tuple |
|
import json |
|
from pathlib import Path |
|
import warnings |
|
import os |
|
import sys |
|
import importlib.util |
|
|
|
|
|
warnings.filterwarnings("ignore") |
|
os.environ["TRANSFORMERS_VERBOSITY"] = "error" |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
def ensure_all_files_downloaded():
    """Download the full repository snapshot when this module is imported.

    Returns the local snapshot directory, or None if the files could not be
    verified or downloaded.
    """
    try:
        from huggingface_hub import snapshot_download, HfApi

        repo_id = 'amildravid4292/clip-vitb16-test-time-registers'

        api = HfApi()
        all_files = api.list_repo_files(repo_id)

        print(f"Ensuring all {len(all_files)} repository files are available...")

        # Note: `resume_download` is deprecated in recent huggingface_hub
        # releases (resuming is the default), so it is not passed here.
        local_dir = snapshot_download(repo_id=repo_id, force_download=False)

        print(f"✓ Repository files available at: {local_dir}")

        if str(local_dir) not in sys.path:
            sys.path.insert(0, str(local_dir))
            print(f"✓ Added repository directory to Python path: {local_dir}")

        # Verify that the files the model actually needs made it to disk.
        critical_files = [f for f in all_files if f.endswith(('.py', '.pt', '.json'))]
        missing_critical = [f for f in critical_files if not (Path(local_dir) / f).exists()]

        if missing_critical:
            print(f"Warning: {len(missing_critical)} critical files still missing")

            # Retry the first few individually before giving up.
            from huggingface_hub import hf_hub_download
            for file in missing_critical[:5]:
                try:
                    hf_hub_download(repo_id=repo_id, filename=file, force_download=True)
                    print(f"✓ Downloaded {file}")
                except Exception as e:
                    print(f"✗ Could not download {file}: {e}")
        else:
            print(f"✓ All {len(critical_files)} critical files verified present")

        python_files = [f for f in all_files if f.endswith('.py')]
        print(f"✓ Python files available: {python_files}")

        return local_dir

    except Exception as e:
        print(f"Warning: Could not verify/download all repository files: {e}")
        print("Model may still work if core files are present.")
        return None


# Download (or verify) the snapshot once, at import time.
_repo_dir = ensure_all_files_downloaded()

def safe_import_from_repo(module_name, repo_path):
    """Safely import a module from the downloaded repository.

    Tries a plain import first, then searches `repo_path`, the directories
    around this file, the downloaded snapshot, and the transformers cache.
    """
    global _repo_dir
    if _repo_dir and str(_repo_dir) not in sys.path:
        sys.path.insert(0, str(_repo_dir))
        print(f"✓ Added {_repo_dir} to Python path")

    try:
        # Fast path: the module is already importable.
        return __import__(module_name)
    except ImportError:
        try:
            search_paths = [
                Path(repo_path),
                Path(__file__).parent,
                Path(__file__).parent.parent,
            ]

            if _repo_dir:
                search_paths.append(Path(_repo_dir))

            # Fall back to the transformers cache snapshot, if one exists.
            try:
                from transformers.utils import TRANSFORMERS_CACHE
                repo_cache_name = "models--amildravid4292--clip-vitb16-test-time-registers"
                cache_path = Path(TRANSFORMERS_CACHE) / repo_cache_name / "snapshots"

                if cache_path.exists():
                    snapshot_dirs = [d for d in cache_path.iterdir() if d.is_dir()]
                    if snapshot_dirs:
                        # Prefer the most recently modified snapshot.
                        latest_snapshot = max(snapshot_dirs, key=lambda x: x.stat().st_mtime)
                        search_paths.append(latest_snapshot)

                        if str(latest_snapshot) not in sys.path:
                            sys.path.insert(0, str(latest_snapshot))
            except Exception:
                pass

            for search_dir in search_paths:
                module_path = search_dir / f"{module_name}.py"
                if module_path.exists():
                    if str(search_dir) not in sys.path:
                        sys.path.insert(0, str(search_dir))

                    try:
                        return __import__(module_name)
                    except ImportError:
                        # Load the file directly from its path instead.
                        spec = importlib.util.spec_from_file_location(module_name, module_path)
                        module = importlib.util.module_from_spec(spec)
                        sys.modules[module_name] = module
                        spec.loader.exec_module(module)
                        print(f"✓ Successfully imported {module_name} from {search_dir}")
                        return module

            searched_locations = [str(p) for p in search_paths]
            raise ImportError(
                f"Could not find {module_name}.py in any of these locations: {searched_locations}"
            )

        except Exception as e:
            raise ImportError(f"Failed to import {module_name}: {e}") from e

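# Example (sketch): this is how the rest of this file pulls in the
# repository's model.py; any module shipped alongside this file can be
# loaded the same way.
#
#     model_module = safe_import_from_repo('model', Path(__file__).parent)
#     CLIP = model_module.CLIP
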
class CustomCLIPConfig(PretrainedConfig):
    model_type = "custom_clip_with_registers"

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        num_register_tokens=0,
        neuron_dict=None,
        projection_dim=512,
        logit_scale_init_value=2.6592,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.vision_config = vision_config or {}
        self.text_config = text_config or {}
        self.num_register_tokens = num_register_tokens
        self.neuron_dict = neuron_dict
        self.projection_dim = projection_dim
        self.logit_scale_init_value = logit_scale_init_value

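# Example (sketch, values hypothetical): a ViT-B/16-sized configuration with
# four register tokens. `neuron_dict` is checkpoint-specific and is normally
# restored from the state dict rather than set by hand.
#
#     config = CustomCLIPConfig(
#         vision_config={"hidden_size": 768, "num_hidden_layers": 12,
#                        "patch_size": 16, "image_size": 224},
#         text_config={"hidden_size": 512, "num_hidden_layers": 12},
#         num_register_tokens=4,
#     )
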
class CustomCLIPModel(PreTrainedModel):
    config_class = CustomCLIPConfig

    def __init__(self, config):
        super().__init__(config)

        # Pull the CLIP implementation from the repository's model.py.
        try:
            model_module = safe_import_from_repo('model', Path(__file__).parent)
            self.CLIP = model_module.CLIP
            self.CLIPVisionCfg = model_module.CLIPVisionCfg
            self.CLIPTextCfg = model_module.CLIPTextCfg
        except ImportError as e:
            raise ImportError(
                f"Could not import model components: {e}. "
                "Make sure all model files are in the repository."
            )

        vision_cfg = self.CLIPVisionCfg(
            layers=config.vision_config.get("num_hidden_layers", 12),
            width=config.vision_config.get("hidden_size", 768),
            patch_size=config.vision_config.get("patch_size", 16),
            image_size=config.vision_config.get("image_size", 224),
        )

        text_cfg = self.CLIPTextCfg(
            context_length=config.text_config.get("max_position_embeddings", 77),
            vocab_size=config.text_config.get("vocab_size", 49408),
            width=config.text_config.get("hidden_size", 512),
            layers=config.text_config.get("num_hidden_layers", 12),
        )

        self.model = self.CLIP(
            embed_dim=config.projection_dim,
            vision_cfg=vision_cfg,
            text_cfg=text_cfg,
        )

        # Register-token state; populated from the checkpoint in
        # _load_from_state_dict.
        self.neuron_dict = None
        self.num_register_tokens = 0

        # Companion components loaded by _load_additional_components.
        self._tokenizer = None
        self._preprocessor = None
        self._zeroshot_classifier = None

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Override to pull custom parameters out of the checkpoint before loading weights."""
        # The checkpoint stores the register settings alongside the weights.
        if 'neuron_dict' in state_dict:
            self.neuron_dict = state_dict.pop('neuron_dict')

        if 'num_register_tokens' in state_dict:
            self.num_register_tokens = state_dict.pop('num_register_tokens')

        # Propagate the settings to the wrapped CLIP model and its vision tower.
        if hasattr(self.model, 'visual'):
            self.model.visual.num_register_tokens = self.num_register_tokens
            self.model.visual.neuron_dict = self.neuron_dict
            self.model.num_register_tokens = self.num_register_tokens
            self.model.neuron_dict = self.neuron_dict

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            original_level = logging.get_verbosity()
            logging.set_verbosity_error()

            try:
                # Load directly into the wrapped model; the keys carry no HF prefix.
                self.model.load_state_dict(state_dict, strict=False)
            except Exception:
                # Fall back to the default (non-strict) loading path.
                super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                              False, [], [], [])
            finally:
                logging.set_verbosity(original_level)

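    # Sketch of the extra, non-tensor entries this loader expects in the
    # checkpoint's state dict (values hypothetical):
    #
    #     state_dict["num_register_tokens"] = 4
    #     state_dict["neuron_dict"] = {...}  # checkpoint-specific mapping
    #     torch.save(state_dict, "pytorch_model.bin")
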
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Override to load cleanly and suppress warnings."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            original_level = logging.get_verbosity()
            logging.set_verbosity_error()

            try:
                model = super().from_pretrained(
                    pretrained_model_name_or_path, *model_args, **kwargs
                )
            finally:
                logging.set_verbosity(original_level)

        # Fetch the tokenizer, preprocessor, and zero-shot classifier that
        # ship alongside the weights.
        model._load_additional_components(pretrained_model_name_or_path)

        print("Custom CLIP model loaded successfully!")

        return model

    def _load_additional_components(self, pretrained_model_name_or_path):
        """Load tokenizer, preprocessor, and zero-shot classifier silently."""
        try:
            from huggingface_hub import hf_hub_download

            # Tokenizer: the repository's tokenizer.py provides SimpleTokenizer.
            try:
                tokenizer_module = safe_import_from_repo('tokenizer', Path(__file__).parent)
                self._tokenizer = tokenizer_module.SimpleTokenizer()
            except ImportError:
                pass

            # Image preprocessor, built from preprocessor_config.json.
            try:
                preprocess_config_file = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename="preprocessor_config.json",
                )

                with open(preprocess_config_file, 'r') as f:
                    preprocess_config = json.load(f)

                self._create_preprocessor(preprocess_config)
            except Exception:
                pass

            # Optional zero-shot classifier head.
            try:
                classifier_file = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename="zeroshot_classifier.pt",
                )

                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    # weights_only=False is required to unpickle the saved
                    # object, so only load checkpoints from a trusted source.
                    self._zeroshot_classifier = torch.load(
                        classifier_file, map_location='cpu', weights_only=False
                    )
            except Exception:
                pass

        except Exception:
            pass

    def _create_preprocessor(self, config):
        """Create the image preprocessor from preprocessor_config.json."""
        try:
            from torchvision import transforms

            self._preprocessor = transforms.Compose([
                transforms.Resize(config["image_size"],
                                  interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.CenterCrop(config["image_size"]),
                transforms.ToTensor(),
                transforms.Normalize(mean=config["image_mean"], std=config["image_std"]),
            ])
        except Exception:
            pass

    @property
    def tokenizer(self):
        """Access the tokenizer."""
        return self._tokenizer

    @property
    def preprocessor(self):
        """Access the image preprocessor."""
        return self._preprocessor

    @property
    def zeroshot_classifier(self):
        """Access the zero-shot classifier."""
        return self._zeroshot_classifier

    def tokenize(self, texts, context_length=77):
        """Tokenize text using the repository's tokenizer."""
        if self._tokenizer is None:
            raise ValueError("Tokenizer not available. Make sure tokenizer.py is in the repository.")

        # Delegate to the module-level tokenize() helper from tokenizer.py.
        try:
            tokenizer_module = safe_import_from_repo('tokenizer', Path(__file__).parent)
            return tokenizer_module.tokenize(texts, context_length)
        except ImportError:
            raise ValueError("Could not import the tokenize function from tokenizer.py.")

    def preprocess_image(self, image):
        """Preprocess an image using the loaded preprocessor."""
        if self._preprocessor is None:
            raise ValueError("Preprocessor not loaded. Make sure preprocessor_config.json is in the repository.")

        return self._preprocessor(image)

    def forward(self, input_ids=None, pixel_values=None, num_register_tokens=None,
                neuron_dict=None, **kwargs):
        """Forward pass with register-token support.

        Falls back to the values restored from the checkpoint when
        num_register_tokens or neuron_dict is not given.
        """
        if num_register_tokens is None:
            num_register_tokens = self.num_register_tokens
        if neuron_dict is None:
            neuron_dict = self.neuron_dict

        return self.model(
            image=pixel_values,
            text=input_ids,
            num_register_tokens=num_register_tokens,
            neuron_dict=neuron_dict,
        )

    def encode_image(self, pixel_values, num_register_tokens=None, neuron_dict=None, **kwargs):
        """Encode images with register-token support."""
        if num_register_tokens is None:
            num_register_tokens = self.num_register_tokens
        if neuron_dict is None:
            neuron_dict = self.neuron_dict

        return self.model.encode_image(
            pixel_values,
            num_register_tokens=num_register_tokens,
            neuron_dict=neuron_dict,
            **kwargs,
        )

    def encode_text(self, input_ids, **kwargs):
        """Encode text."""
        return self.model.encode_text(input_ids, **kwargs)

# Keep transformers quiet for the rest of the session.
import transformers
transformers.logging.set_verbosity_error()
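

if __name__ == "__main__":
    # Minimal smoke test (sketch): loads the published checkpoint and scores
    # two captions against one image. Assumes network access, torchvision,
    # Pillow, and a local image at "example.jpg" (hypothetical path).
    from PIL import Image

    model = CustomCLIPModel.from_pretrained(
        "amildravid4292/clip-vitb16-test-time-registers"
    )
    model.eval()

    image = model.preprocess_image(Image.open("example.jpg").convert("RGB")).unsqueeze(0)
    text = model.tokenize(["a photo of a cat", "a photo of a dog"])

    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)

    # Cosine similarity between the image and each caption.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    print((image_features @ text_features.T).squeeze(0))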