# DiT-XL-2-256-NdLinear / modeling_ndlinear_dit.py
# Author: Zhongfang Zhuang
# Last commit: "Update modeling_ndlinear_dit.py" (d154b53, verified)
# Copied from the Hugging Face Hub file view (371 bytes at that commit).
# modeling_ndlinear_dit.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

import models_hf
from mlp import NdMlp
from models_hf import DiT, DiTConfig
from ndlinear import NdLinear
class DiTConfig(models_hf.DiTConfig):
    """Configuration for the NdLinear variant of DiT.

    Inherits from ``models_hf.DiTConfig`` rather than re-declaring an empty
    ``PretrainedConfig`` subclass. Defining a class with the same name in
    this module rebinds ``DiTConfig`` and would otherwise hide whatever the
    imported configuration defines.

    NOTE(review): assumes ``models_hf.DiTConfig`` derives from
    ``transformers.PretrainedConfig`` -- TODO confirm against models_hf.
    """

    # Identifier transformers records for this configuration type
    # (serialized as the ``model_type`` key of config.json).
    model_type = "ndlinear_dit"
class DiT(models_hf.DiT):
    """Hugging Face entry point for the NdLinear DiT model.

    Inherits the network definition from ``models_hf.DiT``. A bare
    ``PreTrainedModel`` subclass declares no layers and no ``forward``.
    Reusing the name ``DiT`` in this module would also hide the imported
    implementation, so loading this class would yield an empty model.

    NOTE(review): assumes ``models_hf.DiT`` derives from
    ``transformers.PreTrainedModel`` and its ``__init__`` accepts a config
    instance -- TODO confirm against models_hf.
    """

    # Tells transformers which configuration class belongs to this model,
    # i.e. the ``ndlinear_dit`` config defined in this module.
    config_class = DiTConfig
# Names exported by ``from modeling_ndlinear_dit import *``: the model
# class and its configuration class.
__all__ = [
    "DiT",
    "DiTConfig",
]