# Bird-MAE-Base / configuration_bird_mae.py
# From the Hugging Face Hub repository "Bird-MAE-Base" (commit ed8dd1b, by mwirth7).
from transformers import PretrainedConfig
from typing import Literal
class BirdMAEConfig(PretrainedConfig):
    """This represents the Bird-MAE-Base config from the original paper.

    All architecture hyperparameters are stored verbatim as attributes for the
    model code to consume; a few derived sizes are computed here as well.

    Args:
        img_size_x: Input extent along the x axis. ``num_patches_x`` is
            ``img_size_x // patch_size`` (floor division, so a remainder is
            silently dropped).
        img_size_y: Input extent along the y axis. ``num_patches_y`` is
            ``img_size_y // patch_size``.
        patch_size: Patch side length, applied to both axes.
        in_chans: Number of input channels.
        embed_dim, depth, num_heads, mlp_ratio, pos_trainable, qkv_bias,
            qk_norm: Transformer hyperparameters, stored as-is.
        init_values: Optional initialization value, stored as-is; ``None``
            by default (presumably disables layer scale in the model —
            confirm against the model code).
        drop_rate: Base dropout rate. Each of ``pos_drop_rate``,
            ``attn_drop_rate``, ``drop_path_rate`` and ``proj_drop_rate``
            that is left as ``None`` falls back to this value.
        norm_layer_eps: Epsilon stored for the normalization layers.
        global_pool: Pooling mode: ``"cls"``, ``"mean"`` or ``None``.
        pos_drop_rate: Optional override of ``drop_rate``.
        attn_drop_rate: Optional override of ``drop_rate``.
        drop_path_rate: Optional override of ``drop_rate``.
        proj_drop_rate: Optional override of ``drop_rate``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    _auto_class = "AutoConfig"

    def __init__(
        self,
        img_size_x: int = 512,
        img_size_y: int = 128,
        patch_size: int = 16,
        in_chans: int = 1,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4,
        pos_trainable: bool = False,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        init_values: float | None = None,
        drop_rate: float = 0.0,
        norm_layer_eps: float = 1e-6,
        global_pool: Literal["cls", "mean"] | None = "mean",
        pos_drop_rate: float | None = None,
        attn_drop_rate: float | None = None,
        drop_path_rate: float | None = None,
        proj_drop_rate: float | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.img_size_x = img_size_x
        self.img_size_y = img_size_y
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.pos_trainable = pos_trainable
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm
        self.init_values = init_values
        self.drop_rate = drop_rate
        # These rates are serialized by save_pretrained and come back as
        # constructor arguments on reload. Accepting them explicitly (instead
        # of always overwriting with drop_rate) keeps customized values
        # across a save/load round trip.
        self.pos_drop_rate = drop_rate if pos_drop_rate is None else pos_drop_rate
        self.attn_drop_rate = drop_rate if attn_drop_rate is None else attn_drop_rate
        self.drop_path_rate = drop_rate if drop_path_rate is None else drop_path_rate
        self.proj_drop_rate = drop_rate if proj_drop_rate is None else proj_drop_rate
        self.norm_layer_eps = norm_layer_eps
        self.global_pool = global_pool
        # Calculated properties (useful for initializing the model). They are
        # always recomputed from the sizes above, so stale values passed back
        # in via kwargs on reload are overwritten consistently.
        self.num_patches_x = img_size_x // patch_size
        self.num_patches_y = img_size_y // patch_size
        self.num_patches = self.num_patches_x * self.num_patches_y
        # One token more than the patch count (presumably a class token, given
        # the "cls" pooling option — confirm against the model code).
        self.num_tokens = self.num_patches + 1