"""
Sparse Autoencoder (SAE) Implementation

This module implements various sparse autoencoder architectures and activation functions
designed to learn interpretable features in high-dimensional data.
"""
from typing import Callable, Any
from functools import partial
import warnings

import torch
import torch.nn as nn
def normalize_data(x: torch.Tensor, eps: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Normalize input data to zero mean and unit variance along the last dimension.

    Args:
        x (torch.Tensor): Input tensor to normalize
        eps (float, optional): Small constant for numerical stability. Defaults to 1e-5.

    Returns:
        tuple: (normalized_data, mean, std)
            - normalized_data: Data normalized to zero mean and unit variance
            - mean: Mean of the original data (for denormalization)
            - std: Standard deviation of the original data (for denormalization)
    """
    mu = x.mean(dim=-1, keepdim=True)
    x = x - mu
    std = x.std(dim=-1, keepdim=True)
    x = x / (std + eps)
    return x, mu, std
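
# Illustrative sketch: the returned statistics undo the normalization exactly
# (note the eps used inside normalize_data).
#
#     x = torch.randn(8, 512)
#     x_norm, mu, std = normalize_data(x)
#     x_restored = x_norm * (std + 1e-5) + mu   # matches x up to floating-point error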

class SoftCapping(nn.Module):
    """
    Soft capping layer to prevent latent activations from growing excessively large.

    This layer applies a scaled tanh transformation that smoothly saturates values
    without hard truncation, helping stabilize training.

    Args:
        soft_cap (float): The scale factor for the tanh transformation
    """

    def __init__(self, soft_cap):
        super().__init__()
        self.soft_cap = soft_cap

    def forward(self, logits):
        """
        Apply soft capping to input values.

        Args:
            logits (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: Soft-capped values with range approximately [-soft_cap, soft_cap]
        """
        return self.soft_cap * torch.tanh(logits / self.soft_cap)
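
# Illustrative sketch: the cap is near-identity for small values and saturates
# smoothly toward ±soft_cap for large ones.
#
#     cap = SoftCapping(soft_cap=30.0)
#     cap(torch.tensor([1.0, 100.0]))   # ≈ tensor([1.0, 29.9])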

class TopK(nn.Module):
    """
    Top-K activation function that only keeps the K largest activations per sample.

    This activation enforces sparsity by zeroing out all but the k highest values in each
    input vector. Can optionally use absolute values for selection and apply a subsequent
    activation function.

    Args:
        k (int): Number of activations to keep
        act_fn (Callable, optional): Secondary activation function to apply to the kept values.
            Defaults to nn.ReLU().
        use_abs (bool, optional): If True, selection is based on absolute values. Defaults to False.
    """

    def __init__(self, k: int, act_fn: Callable = nn.ReLU(), use_abs: bool = False) -> None:
        super().__init__()
        self.k = k
        self.act_fn = act_fn
        self.use_abs = use_abs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass that keeps only the top-k activations for each sample.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, features]

        Returns:
            torch.Tensor: Sparse output tensor with the same shape as the input, where all but
                the top k values (per sample) are zero
        """
        if self.use_abs:
            x = torch.abs(x)
        # Get indices of the top-k values along the feature dimension
        _, indices = torch.topk(x, k=self.k, dim=-1)
        # Gather the corresponding values from the original input
        values = torch.gather(x, -1, indices)
        # Apply the activation function to the selected values
        activated_values = self.act_fn(values)
        # Create a tensor of zeros and place the activated values at the correct positions
        result = torch.zeros_like(x)
        result.scatter_(-1, indices, activated_values)
        # Verify the sparsity constraint is met
        assert (result != 0.0).sum(dim=-1).max() <= self.k
        return result

    def forward_eval(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluation-mode forward pass that does not enforce sparsity.

        Used for computing full activations during evaluation or visualization.

        Args:
            x (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: Output after applying the activation function (without top-k filtering)
        """
        if self.use_abs:
            x = torch.abs(x)
        x = self.act_fn(x)
        return x
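
# Illustrative sketch: TopK keeps the k largest entries per row and zeroes the rest.
#
#     topk = TopK(k=2)
#     x = torch.tensor([[0.5, -1.0, 2.0, 0.1]])
#     topk(x)               # tensor([[0.5, 0.0, 2.0, 0.0]]) after ReLU on the kept values
#     topk.forward_eval(x)  # dense ReLU output, no top-k filtering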

# Mapping of activation function names to their corresponding classes
ACTIVATIONS_CLASSES = {
    "ReLU": nn.ReLU,
    "Identity": nn.Identity,
    "TopK": partial(TopK, act_fn=nn.Identity()),
    "TopKReLU": partial(TopK, act_fn=nn.ReLU()),
    "TopKabs": partial(TopK, use_abs=True, act_fn=nn.Identity()),
    "TopKabsReLU": partial(TopK, use_abs=True, act_fn=nn.ReLU()),
}


def get_activation(activation: str) -> nn.Module:
    """
    Factory function to create activation function instances by name.

    Handles parameterized activations such as "TopKReLU_64" (TopKReLU with k=64).

    Args:
        activation (str): Name of the activation function, with an optional parameter
            (e.g., "TopKReLU_64" for TopKReLU with k=64)

    Returns:
        nn.Module: Instantiated activation function
    """
    if "_" in activation:
        activation, arg = activation.split("_")
        if "TopK" in activation:
            return ACTIVATIONS_CLASSES[activation](k=int(arg))
        elif "JumpReLU" in activation:
            # Note: no JumpReLU class is registered in ACTIVATIONS_CLASSES, so this
            # branch only works if such a class is added to the mapping.
            return ACTIVATIONS_CLASSES[activation](hidden_dim=int(arg))
    return ACTIVATIONS_CLASSES[activation]()
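
# Illustrative sketch of the naming convention handled above:
#
#     get_activation("ReLU")         # plain nn.ReLU()
#     get_activation("TopKReLU_64")  # TopK(k=64, act_fn=nn.ReLU())
#     get_activation("TopKabs_32")   # TopK(k=32, use_abs=True, act_fn=nn.Identity())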

class Autoencoder(nn.Module):
    """
    Sparse autoencoder base class.

    Implements the standard sparse autoencoder architecture:
        latents = activation(encoder(x - pre_bias) + latent_bias)
        recons = decoder(latents) + pre_bias

    Includes various options for controlling activation functions, weight initialization,
    and feature normalization.

    Attributes:
        n_latents (int): Number of latent features (neurons)
        n_inputs (int): Dimensionality of the input data
        tied (bool): Whether decoder weights are tied to encoder weights
        normalize (bool): Whether to normalize input data
        encoder (nn.Parameter): Encoder weight matrix [n_inputs, n_latents]
        decoder (nn.Parameter): Decoder weight matrix [n_latents, n_inputs] (if not tied)
        pre_bias (nn.Parameter): Input bias/offset [n_inputs]
        latent_bias (nn.Parameter): Latent bias [n_latents]
        activation (nn.Module): Activation function for the latent layer
        latents_activation_frequency (torch.Tensor): Tracks how often neurons activate
    """

    def __init__(
        self, n_latents: int, n_inputs: int, activation: Callable = nn.ReLU(), tied: bool = False, normalize: bool = False,
        bias_init: torch.Tensor | float = 0.0, init_method: str = "kaiming", latent_soft_cap: float = 30.0,
        threshold: torch.Tensor | None = None, *args, **kwargs
    ) -> None:
        """
        Initialize the sparse autoencoder.

        Args:
            n_latents (int): Dimension of the autoencoder latent space
            n_inputs (int): Dimensionality of the original data
            activation (Callable or str): Activation function or name
            tied (bool, optional): Whether to tie encoder and decoder weights. Defaults to False.
            normalize (bool, optional): Whether to normalize input data. Defaults to False.
            bias_init (torch.Tensor | float, optional): Initial bias value. Defaults to 0.0.
            init_method (str, optional): Weight initialization method. Defaults to "kaiming".
            latent_soft_cap (float, optional): Soft cap value for latent activations. Defaults to 30.0.
            threshold (torch.Tensor, optional): Threshold for JumpReLU. Defaults to None.
        """
        super().__init__()
        if isinstance(activation, str):
            activation = get_activation(activation)
        # Store configuration
        self.tied = tied
        self.n_latents = n_latents
        self.n_inputs = n_inputs
        self.init_method = init_method
        self.bias_init = bias_init
        self.normalize = normalize
        # Initialize parameters
        self.pre_bias = nn.Parameter(torch.full((n_inputs,), bias_init) if isinstance(bias_init, float) else bias_init)
        self.encoder = nn.Parameter(torch.zeros((n_inputs, n_latents)))
        self.latent_bias = nn.Parameter(torch.zeros(n_latents,))
        # For tied weights, the decoder is derived from the encoder
        if tied:
            self.register_parameter('decoder', None)
        else:
            self.decoder = nn.Parameter(torch.zeros((n_latents, n_inputs)))
        # Set up activation functions
        self.latent_soft_cap = SoftCapping(latent_soft_cap) if latent_soft_cap > 0 else nn.Identity()
        self.activation = activation
        self.dead_activations = activation
        # Initialize weights
        self._init_weights()
        # Set up activation tracking
        self.latents_activation_frequency: torch.Tensor
        self.register_buffer(
            "latents_activation_frequency", torch.zeros(n_latents, dtype=torch.int64, requires_grad=False)
        )
        self.num_updates = 0
        self.dead_latents = []

    def get_and_reset_stats(self) -> torch.Tensor:
        """
        Get activation statistics and reset the counters.

        Returns:
            torch.Tensor: Proportion of samples that activated each neuron
        """
        activations = self.latents_activation_frequency.detach().cpu().float() / self.num_updates
        self.latents_activation_frequency.zero_()
        self.num_updates = 0
        return activations

    def _init_weights(self, norm=0.1, neuron_indices: list[int] | None = None) -> None:
        """
        Initialize network weights.

        Args:
            norm (float, optional): Target norm for the weights. Defaults to 0.1.
            neuron_indices (list[int] | None, optional): Indices of neurons to initialize.
                If None, initialize all neurons.

        Raises:
            ValueError: If an invalid initialization method is specified
        """
        if self.init_method not in ["kaiming", "xavier", "uniform", "normal"]:
            raise ValueError(f"Invalid init_method: {self.init_method}")
        # Use the transposed encoder if weights are tied
        if self.tied:
            decoder_weight = self.encoder.t()
        else:
            decoder_weight = self.decoder
        # Initialize with the specified method
        if self.init_method == "kaiming":
            new_W_dec = nn.init.kaiming_uniform_(torch.zeros_like(decoder_weight), nonlinearity='relu')
        elif self.init_method == "xavier":
            new_W_dec = nn.init.xavier_uniform_(torch.zeros_like(decoder_weight), gain=nn.init.calculate_gain('relu'))
        elif self.init_method == "uniform":
            new_W_dec = nn.init.uniform_(torch.zeros_like(decoder_weight), a=-1, b=1)
        elif self.init_method == "normal":
            new_W_dec = nn.init.normal_(torch.zeros_like(decoder_weight))
        else:
            raise ValueError(f"Invalid init_method: {self.init_method}")
        # Normalize to the target norm
        new_W_dec *= (norm / new_W_dec.norm(p=2, dim=-1, keepdim=True))
        # Initialize bias to zero
        new_l_bias = torch.zeros_like(self.latent_bias)
        # Transpose for the encoder
        new_W_enc = new_W_dec.t().clone()
        # Apply initialization to all or specific neurons
        if neuron_indices is None:
            if not self.tied:
                self.decoder.data = new_W_dec
            self.encoder.data = new_W_enc
            self.latent_bias.data = new_l_bias
        else:
            if not self.tied:
                self.decoder.data[neuron_indices] = new_W_dec[neuron_indices]
            self.encoder.data[:, neuron_indices] = new_W_enc[:, neuron_indices]
            self.latent_bias.data[neuron_indices] = new_l_bias[neuron_indices]

    def project_grads_decode(self):
        """
        Project out the component of the decoder gradient that would change the row norms.

        This helps keep the decoder rows normalized during training.
        """
        if self.tied:
            weights = self.encoder.data.T
            grad = self.encoder.grad.T
        else:
            weights = self.decoder.data
            grad = self.decoder.grad
        # Project out the component parallel to the weights
        grad_proj = (grad * weights).sum(dim=-1, keepdim=True) * weights
        # Update the gradients
        if self.tied:
            self.encoder.grad -= grad_proj.T
        else:
            self.decoder.grad -= grad_proj
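
    # Assuming the decoder rows w are kept at (approximately) unit norm, the projection above
    # removes the component of the gradient g along each row (g <- g - (g . w) w), so a gradient
    # step leaves ||w|| unchanged to first order; scale_to_unit_norm below can then correct any
    # residual drift.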

    def scale_to_unit_norm(self) -> None:
        """
        Scale decoder rows to unit norm and adjust the other parameters accordingly.

        This normalization helps with feature interpretability and training stability.
        """
        weight = self.encoder if self.tied else self.decoder
        eps = torch.finfo(weight.dtype).eps
        # Normalize tied or untied weights
        if self.tied:
            # Decoder rows are the columns of the encoder when weights are tied
            norm = self.encoder.data.t().norm(p=2, dim=-1, keepdim=True) + eps
            self.encoder.data /= norm.t()
        else:
            norm = self.decoder.data.norm(p=2, dim=-1, keepdim=True) + eps
            self.decoder.data /= norm
            self.encoder.data *= norm.t()
        # Scale the biases accordingly
        self.latent_bias.data *= norm.squeeze()

    def encode_pre_act(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute pre-activation latent values.

        Args:
            x (torch.Tensor): Input data [batch, n_inputs]

        Returns:
            torch.Tensor: Pre-activation latent values [batch, n_latents]
        """
        x = x - self.pre_bias
        latents_pre_act_full = x @ self.encoder + self.latent_bias
        return latents_pre_act_full

    def preprocess(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
        """
        Preprocess input data, optionally normalizing it.

        Args:
            x (torch.Tensor): Input data [batch, n_inputs]

        Returns:
            tuple: (preprocessed_data, normalization_info)
                - preprocessed_data: Processed input data
                - normalization_info: Dict with normalization parameters (if normalize=True)
        """
        if not self.normalize:
            return x, dict()
        x_processed, mu, std = normalize_data(x)
        return x_processed, dict(mu=mu, std=std)

    def encode(self, x: torch.Tensor, topk_number: int | None = None) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
        """
        Encode input data to latent representations.

        Args:
            x (torch.Tensor): Input data [batch, n_inputs]
            topk_number (int | None, optional): Number of top-k activations to keep (for inference).
                Defaults to None.

        Returns:
            tuple: (encoded, full_encoded, info)
                - encoded: Latent activations with sparsity constraints [batch, n_latents]
                - full_encoded: Latent activations without sparsity (for analysis) [batch, n_latents]
                - info: Normalization information dictionary
        """
        x, info = self.preprocess(x)
        pre_encoded = self.encode_pre_act(x)
        encoded = self.activation(pre_encoded)
        # Get full activations (for analysis) depending on the activation type
        if isinstance(self.activation, TopK):
            full_encoded = self.activation.forward_eval(pre_encoded)
        else:
            full_encoded = torch.clone(encoded)
        # Apply top-k filtering for inference if requested
        if topk_number is not None:
            _, indices = torch.topk(full_encoded, k=topk_number, dim=-1)
            values = torch.gather(full_encoded, -1, indices)
            full_encoded = torch.zeros_like(full_encoded)
            full_encoded.scatter_(-1, indices, values)
        # Apply soft capping to both outputs
        capped_encoded = self.latent_soft_cap(encoded)
        capped_full_encoded = self.latent_soft_cap(full_encoded)
        return capped_encoded, capped_full_encoded, info

    def decode(self, latents: torch.Tensor, info: dict[str, Any] | None = None) -> torch.Tensor:
        """
        Decode latent representations to reconstructed inputs.

        Args:
            latents (torch.Tensor): Latent activations [batch, n_latents]
            info (dict[str, Any] | None, optional): Normalization information. Defaults to None.

        Returns:
            torch.Tensor: Reconstructed input data [batch, n_inputs]
        """
        # Decode using tied or untied weights
        if self.tied:
            ret = latents @ self.encoder.t() + self.pre_bias
        else:
            ret = latents @ self.decoder + self.pre_bias
        # Denormalize if needed
        if self.normalize:
            assert info is not None
            ret = ret * info["std"] + info["mu"]
        return ret

    def update_latent_statistics(self, latents: torch.Tensor) -> None:
        """
        Update statistics on latent activations.

        Args:
            latents (torch.Tensor): Latent activations [batch, n_latents]
        """
        self.num_updates += latents.shape[0]
        current_activation_frequency = (latents != 0).to(torch.int64).sum(dim=0)
        self.latents_activation_frequency += current_activation_frequency

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass through the autoencoder.

        Args:
            x (torch.Tensor): Input data [batch, n_inputs]

        Returns:
            tuple: (recons, latents, all_recons, all_latents)
                - recons: Reconstructed data [batch, n_inputs]
                - latents: Latent activations [batch, n_latents]
                - all_recons: Reconstructed data without sparsity constraints (for analysis)
                - all_latents: Latent activations without sparsity constraints (for analysis)
        """
        # Preprocess the data
        x_processed, info = self.preprocess(x)
        # Compute pre-activations
        latents_pre_act = self.encode_pre_act(x_processed)
        # Apply the activation function and soft cap
        latents = self.activation(latents_pre_act)
        latents_capped = self.latent_soft_cap(latents)
        # Decode to a reconstruction
        recons = self.decode(latents_capped, info)
        # Update activation statistics
        self.update_latent_statistics(latents_capped)
        # Handle different activation function types for the analysis outputs
        if isinstance(self.activation, TopK):
            # For TopK, return both sparse and full activations
            all_latents = self.activation.forward_eval(latents_pre_act)
            all_latents_capped = self.latent_soft_cap(all_latents)
            all_recons = self.decode(all_latents_capped, info)
            return recons, latents_capped, all_recons, all_latents_capped
        else:
            # For other activations, return the same tensors for both
            return recons, latents_capped, recons, latents_capped
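
# Illustrative sketch (hypothetical sizes): basic training-style usage of the base Autoencoder.
#
#     sae = Autoencoder(n_latents=4096, n_inputs=768, activation="TopKReLU_64")
#     x = torch.randn(32, 768)
#     recons, latents, all_recons, all_latents = sae(x)
#     loss = (recons - x).pow(2).mean()
#     # One plausible ordering after loss.backward():
#     #     sae.project_grads_decode(); optimizer.step(); sae.scale_to_unit_norm()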

class MatryoshkaAutoencoder(Autoencoder):
    """
    Matryoshka Sparse Autoencoder.

    This extends the base Autoencoder with a nested structure of latent representations,
    where different numbers of features can be used depending on the computational budget
    or the desired level of detail.

    The model uses multiple TopK activations with different k values and maintains
    relative importance weights for each level of the hierarchy.
    """

    def __init__(
        self, n_latents: int, n_inputs: int, activation: str = "TopKReLU", tied: bool = False, normalize: bool = False,
        bias_init: torch.Tensor | float = 0.0, init_method: str = "kaiming", latent_soft_cap: float = 30.0,
        nesting_list: list[int] = [16, 32], relative_importance: list[float] | None = None, *args, **kwargs
    ) -> None:
        """
        Initialize the Matryoshka Sparse Autoencoder.

        Args:
            n_latents (int): Dimension of the autoencoder latent space
            n_inputs (int): Dimensionality of the original data
            activation (str, optional): Base activation function name. Defaults to "TopKReLU".
            tied (bool, optional): Whether to tie encoder and decoder weights. Defaults to False.
            normalize (bool, optional): Whether to normalize input data. Defaults to False.
            bias_init (torch.Tensor | float, optional): Initial bias value. Defaults to 0.0.
            init_method (str, optional): Weight initialization method. Defaults to "kaiming".
            latent_soft_cap (float, optional): Soft cap value for latent activations. Defaults to 30.0.
            nesting_list (list[int], optional): List of k values for the nested representations. Defaults to [16, 32].
            relative_importance (list[float] | None, optional): Importance weights for each nesting level.
                Defaults to equal weights.
        """
        # Initialize the nesting hierarchy
        self.nesting_list = sorted(nesting_list)
        self.relative_importance = relative_importance if relative_importance is not None else [1.0] * len(nesting_list)
        assert len(self.relative_importance) == len(self.nesting_list)
        # Ensure the activation is TopK-based
        if "TopK" not in activation:
            warnings.warn(f"MatryoshkaAutoencoder: activation {activation} is not a TopK activation; falling back to TopKReLU.")
            activation = "TopKReLU"
        # Initialize with the base (smallest-k) activation
        base_activation = activation + f"_{self.nesting_list[0]}"
        super().__init__(n_latents, n_inputs, base_activation, tied, normalize, bias_init, init_method, latent_soft_cap)
        # Create multiple activations with different k values
        self.activation = nn.ModuleList(
            [get_activation(activation + f"_{nesting}") for nesting in self.nesting_list]
        )

    def encode(self, x: torch.Tensor, topk_number: int | None = None) -> tuple[list[torch.Tensor], torch.Tensor, dict[str, Any]]:
        """
        Encode input data to multiple latent representations with different sparsity levels.

        Args:
            x (torch.Tensor): Input data [batch, n_inputs]
            topk_number (int | None, optional): Number of top-k activations to keep (for inference).
                Defaults to None.

        Returns:
            tuple: (encoded_list, last_encoded, info)
                - encoded_list: List of latent activations with different sparsity levels
                - last_encoded: The least sparse latent activations (from the largest k value)
                - info: Normalization information dictionary
        """
        x, info = self.preprocess(x)
        pre_encoded = self.encode_pre_act(x)
        # Apply each activation function in the hierarchy
        encoded = [activation(pre_encoded) for activation in self.activation]
        capped_encoded = [self.latent_soft_cap(enc) for enc in encoded]
        # Apply additional top-k filtering for inference if requested
        if topk_number is not None:
            last_encoded = capped_encoded[-1]
            _, indices = torch.topk(last_encoded, k=topk_number, dim=-1)
            values = torch.gather(last_encoded, -1, indices)
            last_encoded = torch.zeros_like(last_encoded)
            last_encoded.scatter_(-1, indices, values)
        else:
            last_encoded = capped_encoded[-1]
        return capped_encoded, last_encoded, info

    def decode(self, latents: list[torch.Tensor], info: dict[str, Any] | None = None) -> list[torch.Tensor]:
        """
        Decode multiple latent representations to reconstructions.

        Args:
            latents (list[torch.Tensor]): List of latent activations at different sparsity levels
            info (dict[str, Any] | None, optional): Normalization information. Defaults to None.

        Returns:
            list[torch.Tensor]: List of reconstructed inputs at different sparsity levels
        """
        # Decode each latent representation
        if self.tied:
            ret = [latent @ self.encoder.t() + self.pre_bias for latent in latents]
        else:
            ret = [latent @ self.decoder + self.pre_bias for latent in latents]
        # Denormalize if needed
        if self.normalize:
            assert info is not None
            ret = [re * info["std"] + info["mu"] for re in ret]
        return ret

    def forward(self, x: torch.Tensor) -> tuple[list[torch.Tensor], list[torch.Tensor], torch.Tensor, torch.Tensor]:
        """
        Forward pass through the Matryoshka autoencoder.

        Args:
            x (torch.Tensor): Input data [batch, n_inputs]

        Returns:
            tuple: (recons_list, latents_list, all_recons, all_latents)
                - recons_list: List of reconstructions at different sparsity levels
                - latents_list: List of latent activations at different sparsity levels
                - all_recons: Reconstruction from the unfiltered (dense) activations (for analysis)
                - all_latents: Latent activations without top-k filtering (for analysis)
        """
        # Preprocess the data
        x_processed, info = self.preprocess(x)
        latents_pre_act = self.encode_pre_act(x_processed)
        # Apply each activation in the hierarchy
        latents = [activation(latents_pre_act) for activation in self.activation]
        assert len(latents) == len(self.activation)
        latents_capped = [self.latent_soft_cap(latent) for latent in latents]
        # Decode each level
        recons = self.decode(latents_capped, info)
        assert len(recons) == len(latents)
        # Update activation statistics using the largest k
        self.update_latent_statistics(latents_capped[-1])
        # Get full activations for analysis
        all_latents = self.activation[0].forward_eval(latents_pre_act)
        all_latents_capped = self.latent_soft_cap(all_latents)
        all_recons = self.decode([all_latents_capped], info)[0]
        # Return all reconstructions plus the unfiltered ones
        return recons, latents_capped, all_recons, all_latents_capped
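
# Illustrative sketch (hypothetical sizes): each nesting level yields its own reconstruction,
# which is typically combined into a weighted loss using relative_importance.
#
#     msae = MatryoshkaAutoencoder(n_latents=4096, n_inputs=768, nesting_list=[16, 32, 64])
#     x = torch.randn(32, 768)
#     recons_list, latents_list, all_recons, all_latents = msae(x)
#     loss = sum(w * (r - x).pow(2).mean() for w, r in zip(msae.relative_importance, recons_list))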

def load_model(path):
    """
    Load a saved sparse autoencoder model from a file.

    This function parses the filename to extract the model configuration parameters
    and then loads the saved model weights.

    Args:
        path (str): Path to the saved model file (.pt)

    Returns:
        tuple: (model, mean_center, scaling_factor, target_norm)
            - model: The loaded Autoencoder model
            - mean_center: Mean offset used to center the data (see SAE.preprocess)
            - scaling_factor: Scaling factor applied to the data
            - target_norm: Target normalization factor for the data
    """
    # Extract configuration from the filename
    path_head = path.split("/")[-1]
    path_name = path_head[:path_head.find(".pt")]
    path_name_split = path_name.split("_")
    n_latents = int(path_name_split.pop(0))
    n_inputs = int(path_name_split.pop(0))
    activation = path_name_split.pop(0)
    if "TopK" in activation:
        activation += "_" + path_name_split.pop(0)
    elif activation == "ReLU":
        path_name_split.pop(0)
    if "UW" in path_name_split[0] or "RW" in path_name_split[0]:
        path_name_split.pop(0)
    tied = False if path_name_split.pop(0) == "False" else True
    normalize = False if path_name_split.pop(0) == "False" else True
    latent_soft_cap = float(path_name_split.pop(0))
    # Create and load the model
    model = Autoencoder(n_latents, n_inputs, activation, tied=tied, normalize=normalize, latent_soft_cap=latent_soft_cap)
    model_state_dict = torch.load(path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
    model.load_state_dict(model_state_dict['model'])
    mean_center = model_state_dict['mean_center']
    scaling_factor = model_state_dict['scaling_factor']
    target_norm = model_state_dict['target_norm']
    return model, mean_center, scaling_factor, target_norm
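
# The parsing above implies a filename convention roughly of the form (hypothetical example):
#
#     "{n_latents}_{n_inputs}_{activation}[_{k}][_{UW/RW tag}]_{tied}_{normalize}_{latent_soft_cap}...pt"
#     e.g. "4096_768_TopKReLU_64_False_False_30.0.pt"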

class SAE(nn.Module):
    def __init__(self, path: str) -> None:
        """
        Initialize the Sparse Autoencoder (SAE) wrapper from a saved model file.

        Args:
            path (str): Path to the saved model file (.pt)
        """
        super().__init__()
        self.model, mean, scaling_factor, _ = load_model(path)
        self.register_buffer("mean", mean.clone().detach() if isinstance(mean, torch.Tensor) else torch.tensor(mean))
        self.register_buffer("scaling_factor", scaling_factor.clone().detach() if isinstance(scaling_factor, torch.Tensor) else torch.tensor(scaling_factor))

    def input_dim(self) -> int:
        """Return the input dimension of the model."""
        return self.model.n_inputs

    def latent_dim(self) -> int:
        """Return the latent dimension of the model."""
        return self.model.n_latents

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """
        Preprocess input data (mean-centering and scaling).

        Args:
            x: Input tensor

        Returns:
            Preprocessed tensor
        """
        # Mean-center and scale the input
        x = (x - self.mean) * self.scaling_factor
        return x

    def postprocess(self, x: torch.Tensor) -> torch.Tensor:
        """
        Post-process output data (denormalization).

        Args:
            x: Output tensor

        Returns:
            Denormalized tensor
        """
        # Undo the scaling and restore the mean
        x = x / self.scaling_factor + self.mean
        return x

    def encode(self, x: torch.Tensor, topk: int = -1) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode input data to latent representations.

        Args:
            x: Input tensor
            topk (int, optional): Number of top-k activations to keep. Defaults to -1 (no extra filtering).

        Returns:
            Encoded latents and full latents
        """
        # Preprocess the input
        x = self.preprocess(x)
        # Validate the top-k constraint
        if topk > 0 and topk < self.model.n_latents:
            topk_number = topk
        else:
            topk_number = None
        # Encode using the underlying model
        latents, full_latents, _ = self.model.encode(x, topk_number=topk_number)
        return latents, full_latents

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """
        Decode a latent representation back to input space.

        Args:
            latents: Latent tensor

        Returns:
            Reconstructed input tensor
        """
        # Decode using the underlying model
        reconstructed = self.model.decode(latents)
        # Postprocess the output
        reconstructed = self.postprocess(reconstructed)
        return reconstructed

    def forward(self, x: torch.Tensor, topk: int = -1) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass through the SAE.

        Args:
            x: Input tensor
            topk (int, optional): Number of top-k activations to keep. Defaults to -1 (no extra filtering).

        Returns:
            - Post-processed reconstruction (in the original input space)
            - Raw reconstruction (in the model's preprocessed space)
            - Full latent activations
        """
        # Encode to the latent space
        _, full_latents = self.encode(x, topk=topk)
        # Decode back to input space
        reconstructed = self.model.decode(full_latents)
        # Postprocess the output
        post_reconstructed = self.postprocess(reconstructed)
        return post_reconstructed, reconstructed, full_latents
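
# Illustrative sketch (hypothetical checkpoint path): loading a trained SAE and inspecting
# the reconstruction and latent code of a batch of activations.
#
#     sae = SAE("checkpoints/4096_768_TopKReLU_64_False_False_30.0.pt")
#     x = torch.randn(32, sae.input_dim())
#     recon, raw_recon, latents = sae(x, topk=64)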