|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" Custom PyTorch CLIP model with neuron ablation support.""" |
|
|
|
from typing import Optional, Tuple, Union |
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
|
|
from transformers.models.clip.modeling_clip import ( |
|
CLIPVisionModel, |
|
CLIPVisionTransformer, |
|
CLIPVisionEmbeddings, |
|
CLIPEncoder, |
|
CLIPEncoderLayer, |
|
CLIPMLP, |
|
CLIPVisionConfig, |
|
BaseModelOutputWithPooling |
|
) |
|
from transformers.activations import ACT2FN |
|
from transformers.utils import logging |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class CustomCLIPVisionEmbeddings(CLIPVisionEmbeddings): |
|
def forward(self, pixel_values: torch.FloatTensor, extra_tokens=0) -> torch.Tensor: |
|
batch_size = pixel_values.shape[0] |
|
target_dtype = self.patch_embedding.weight.dtype |
|
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) |
|
patch_embeds = patch_embeds.flatten(2).transpose(1, 2) |
|
|
|
class_embeds = self.class_embedding.expand(batch_size, 1, -1) |
|
embeddings = torch.cat([class_embeds, patch_embeds], dim=1) |
|
embeddings = embeddings + self.position_embedding(self.position_ids) |
|
|
|
|
|
if extra_tokens > 0: |
|
extra_token_embeddings = [] |
|
for i in range(extra_tokens): |
|
|
|
extra_token_embeddings.append( |
|
torch.zeros( |
|
batch_size, 1, self.embed_dim, dtype=embeddings.dtype, device=embeddings.device |
|
) |
|
) |
|
|
|
|
|
embeddings = torch.cat([embeddings, *extra_token_embeddings], dim=1) |
|
|
|
return embeddings |
|
|
|
|
|
class CustomCLIPMLP(CLIPMLP): |
|
def __init__(self, config, layer_id=None): |
|
super().__init__(config) |
|
self.layer_id = layer_id |
|
|
|
def forward(self, hidden_states: torch.Tensor, neuron_dict=None, extra_tokens=0) -> torch.Tensor: |
|
hidden_states = self.fc1(hidden_states) |
|
|
|
if neuron_dict is not None and self.layer_id in neuron_dict: |
|
|
|
neurons = neuron_dict[self.layer_id] |
|
|
|
|
|
x_after_activation = self.activation_fn(hidden_states) |
|
|
|
|
|
original_activations = x_after_activation.clone() |
|
|
|
|
|
new_activation_map = torch.zeros( |
|
(x_after_activation.shape[0], x_after_activation.shape[1], len(neurons)), |
|
device=x_after_activation.device, |
|
).to(x_after_activation.dtype) |
|
|
|
|
|
max_values = torch.max(original_activations[:, :, neurons], dim=1, keepdim=True).values |
|
|
|
|
|
|
|
new_activation_map[:, -extra_tokens:, :] = max_values |
|
|
|
|
|
new_activation_map[:, 0, :] = x_after_activation[:, 0, neurons] |
|
|
|
|
|
x_after_activation[:, :, neurons] = new_activation_map |
|
|
|
|
|
x = x_after_activation |
|
else: |
|
|
|
x = self.activation_fn(hidden_states) |
|
|
|
x = self.fc2(x) |
|
return x |
|
|
|
|
|
class CustomCLIPEncoderLayer(CLIPEncoderLayer): |
|
def __init__(self, config, layer_id=None): |
|
super().__init__(config) |
|
self.mlp = CustomCLIPMLP(config, layer_id=layer_id) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: torch.Tensor, |
|
causal_attention_mask: torch.Tensor, |
|
output_attentions: Optional[bool] = False, |
|
neuron_dict: Optional[dict] = None, |
|
extra_tokens: Optional[int] = 0 |
|
) -> Tuple[torch.FloatTensor]: |
|
residual = hidden_states |
|
|
|
hidden_states = self.layer_norm1(hidden_states) |
|
hidden_states, attn_weights = self.self_attn( |
|
hidden_states=hidden_states, |
|
attention_mask=attention_mask, |
|
causal_attention_mask=causal_attention_mask, |
|
output_attentions=output_attentions, |
|
) |
|
hidden_states = residual + hidden_states |
|
|
|
residual = hidden_states |
|
hidden_states = self.layer_norm2(hidden_states) |
|
hidden_states = self.mlp(hidden_states, neuron_dict=neuron_dict, extra_tokens=extra_tokens) |
|
hidden_states = residual + hidden_states |
|
|
|
outputs = (hidden_states,) |
|
|
|
if output_attentions: |
|
outputs += (attn_weights,) |
|
|
|
return outputs |
|
|
|
|
|
class CustomCLIPEncoder(CLIPEncoder): |
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.layers = nn.ModuleList([CustomCLIPEncoderLayer(config, layer_id=i) for i in range(config.num_hidden_layers)]) |
|
|
|
def forward( |
|
self, |
|
inputs_embeds, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
causal_attention_mask: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
neuron_dict: Optional[dict] = None, |
|
extra_tokens: Optional[int] = 0 |
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
encoder_states = () if output_hidden_states else None |
|
all_attentions = () if output_attentions else None |
|
|
|
hidden_states = inputs_embeds |
|
|
|
for idx, encoder_layer in enumerate(self.layers): |
|
if output_hidden_states: |
|
encoder_states = encoder_states + (hidden_states,) |
|
|
|
if self.gradient_checkpointing and self.training: |
|
layer_outputs = self._gradient_checkpointing_func( |
|
encoder_layer.__call__, |
|
hidden_states, |
|
attention_mask, |
|
causal_attention_mask, |
|
output_attentions, |
|
neuron_dict, |
|
) |
|
else: |
|
layer_outputs = encoder_layer( |
|
hidden_states, |
|
attention_mask, |
|
causal_attention_mask, |
|
output_attentions=output_attentions, |
|
neuron_dict=neuron_dict, |
|
extra_tokens=extra_tokens |
|
) |
|
|
|
hidden_states = layer_outputs[0] |
|
|
|
if output_attentions: |
|
all_attentions = all_attentions + (layer_outputs[1],) |
|
|
|
if output_hidden_states: |
|
encoder_states = encoder_states + (hidden_states,) |
|
|
|
if not return_dict: |
|
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) |
|
|
|
from transformers.modeling_outputs import BaseModelOutput |
|
return BaseModelOutput( |
|
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions |
|
) |
|
|
|
|
|
class CustomCLIPVisionTransformer(CLIPVisionTransformer): |
|
def __init__(self, config: CLIPVisionConfig): |
|
super().__init__(config) |
|
self.embeddings = CustomCLIPVisionEmbeddings(config) |
|
self.encoder = CustomCLIPEncoder(config) |
|
|
|
def forward( |
|
self, |
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
extra_tokens: int = 0, |
|
neuron_dict: Optional[dict] = None, |
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
if pixel_values is None: |
|
raise ValueError("You have to specify pixel_values") |
|
|
|
hidden_states = self.embeddings(pixel_values, extra_tokens) |
|
hidden_states = self.pre_layrnorm(hidden_states) |
|
|
|
encoder_outputs = self.encoder( |
|
inputs_embeds=hidden_states, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
neuron_dict=neuron_dict, |
|
extra_tokens=extra_tokens |
|
) |
|
|
|
last_hidden_state = encoder_outputs[0] |
|
pooled_output = last_hidden_state[:, 0, :] |
|
pooled_output = self.post_layernorm(pooled_output) |
|
|
|
if not return_dict: |
|
return (last_hidden_state, pooled_output) + encoder_outputs[1:] |
|
|
|
return BaseModelOutputWithPooling( |
|
last_hidden_state=last_hidden_state, |
|
pooler_output=pooled_output, |
|
hidden_states=encoder_outputs.hidden_states, |
|
attentions=encoder_outputs.attentions, |
|
) |
|
|
|
|
|
class CustomCLIPVisionModel(CLIPVisionModel): |
|
def __init__(self, config: CLIPVisionConfig): |
|
super().__init__(config) |
|
self.vision_model = CustomCLIPVisionTransformer(config) |
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
extra_tokens: int = 0, |
|
neuron_dict: Optional[dict] = None, |
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
return self.vision_model( |
|
pixel_values=pixel_values, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
extra_tokens=extra_tokens, |
|
neuron_dict=neuron_dict, |
|
) |