# coding=utf-8
# Copyright 2023 Custom LLaVA Neuron Ablation Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Custom PyTorch CLIP model with neuron ablation support."""
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.models.clip.modeling_clip import (
    CLIPVisionModel,
    CLIPVisionTransformer,
    CLIPVisionEmbeddings,
    CLIPEncoder,
    CLIPEncoderLayer,
    CLIPMLP,
    CLIPVisionConfig,
)
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.activations import ACT2FN
from transformers.utils import logging
logger = logging.get_logger(__name__)
class CustomCLIPVisionEmbeddings(CLIPVisionEmbeddings):
def forward(self, pixel_values: torch.FloatTensor, extra_tokens=0) -> torch.Tensor:
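        """
        Compute class + patch embeddings as in the stock CLIP vision tower, then optionally append
        `extra_tokens` zero-initialized register tokens to the end of the sequence.
        """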
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
# Add extra tokens if requested (register tokens for neuron ablation)
        # Append zero-initialized register tokens (used as targets for neuron ablation), if requested
        if extra_tokens > 0:
            register_tokens = torch.zeros(
                batch_size, extra_tokens, self.embed_dim, dtype=embeddings.dtype, device=embeddings.device
            )
            embeddings = torch.cat([embeddings, register_tokens], dim=1)
return embeddings
class CustomCLIPMLP(CLIPMLP):
def __init__(self, config, layer_id=None):
super().__init__(config)
self.layer_id = layer_id # Store layer ID for neuron targeting
def forward(self, hidden_states: torch.Tensor, neuron_dict=None, extra_tokens=0) -> torch.Tensor:
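        """
        Standard CLIP MLP forward pass with optional test-time neuron intervention.

        `neuron_dict` maps an encoder layer index to an iterable of fc1 neuron indices. For this
        layer's entry, those neurons' activations are zeroed at patch positions, kept unchanged at
        the class token, and their per-neuron maximum over the sequence is written into the trailing
        `extra_tokens` register positions.
        """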
hidden_states = self.fc1(hidden_states)
if neuron_dict is not None and self.layer_id in neuron_dict:
# Get the neurons to modify for this layer
neurons = neuron_dict[self.layer_id]
# Apply activation function to all activations
x_after_activation = self.activation_fn(hidden_states)
# Store original activations
original_activations = x_after_activation.clone()
# Create new activation map for specified neurons
new_activation_map = torch.zeros(
(x_after_activation.shape[0], x_after_activation.shape[1], len(neurons)),
device=x_after_activation.device,
).to(x_after_activation.dtype)
            # Max activation of each selected neuron across the sequence dimension
            max_values = torch.max(original_activations[:, :, neurons], dim=1, keepdim=True).values
            if extra_tokens > 0:
                # Route the max activations into the register tokens appended at the end of the sequence
                new_activation_map[:, -extra_tokens:, :] = max_values
            # Keep class token activations unchanged
            new_activation_map[:, 0, :] = x_after_activation[:, 0, neurons]
# Update the selected neurons with new activations
x_after_activation[:, :, neurons] = new_activation_map
# Use the modified activations
x = x_after_activation
else:
# Standard behavior for layers not in the dictionary
x = self.activation_fn(hidden_states)
x = self.fc2(x)
return x
class CustomCLIPEncoderLayer(CLIPEncoderLayer):
def __init__(self, config, layer_id=None):
super().__init__(config)
self.mlp = CustomCLIPMLP(config, layer_id=layer_id)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
neuron_dict: Optional[dict] = None,
extra_tokens: Optional[int] = 0
) -> Tuple[torch.FloatTensor]:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states, neuron_dict=neuron_dict, extra_tokens=extra_tokens)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class CustomCLIPEncoder(CLIPEncoder):
def __init__(self, config):
super().__init__(config)
self.layers = nn.ModuleList([CustomCLIPEncoderLayer(config, layer_id=i) for i in range(config.num_hidden_layers)])
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
neuron_dict: Optional[dict] = None,
extra_tokens: Optional[int] = 0
    ) -> Union[Tuple, BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
                    neuron_dict,
                    extra_tokens,
                )
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
neuron_dict=neuron_dict,
extra_tokens=extra_tokens
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
class CustomCLIPVisionTransformer(CLIPVisionTransformer):
def __init__(self, config: CLIPVisionConfig):
super().__init__(config)
self.embeddings = CustomCLIPVisionEmbeddings(config)
self.encoder = CustomCLIPEncoder(config)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
extra_tokens: int = 0,
neuron_dict: Optional[dict] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values, extra_tokens)
hidden_states = self.pre_layrnorm(hidden_states)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
neuron_dict=neuron_dict,
extra_tokens=extra_tokens
)
last_hidden_state = encoder_outputs[0]
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class CustomCLIPVisionModel(CLIPVisionModel):
def __init__(self, config: CLIPVisionConfig):
super().__init__(config)
self.vision_model = CustomCLIPVisionTransformer(config)
self.post_init()
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
extra_tokens: int = 0,
neuron_dict: Optional[dict] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
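        """
        Same as `CLIPVisionModel.forward`, with two extra keyword arguments: `extra_tokens`, the
        number of zero-initialized register tokens appended to the sequence, and `neuron_dict`, a
        mapping from encoder layer index to the fc1 neuron indices to intervene on in that layer.
        """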
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
extra_tokens=extra_tokens,
neuron_dict=neuron_dict,
)