# modeling_ndlinear_dit.py | |
import torch | |
import torch.nn as nn | |
from transformers import PreTrainedModel, PretrainedConfig | |
from mlp import NdMlp | |
from ndlinear import NdLinear | |
from models_hf import DiT, DiTConfig | |
class DiTConfig(PretrainedConfig):
    """Hugging Face configuration class for the NdLinear DiT model.

    Only ``model_type`` is set here. The constructor and every other
    attribute are inherited unchanged from ``PretrainedConfig``.

    NOTE(review): this definition rebinds the name ``DiTConfig`` that is
    imported from ``models_hf`` at the top of the file, so that imported
    class is no longer reachable under this name. Confirm which of the
    two is meant to be exported.
    """

    # Architecture tag that the transformers library uses to identify
    # configs of this model type.
    model_type = "ndlinear_dit"
class DiT(PreTrainedModel):
    """``PreTrainedModel`` shell for the NdLinear DiT, bound to ``DiTConfig``.

    The body only sets ``config_class``. Construction, weight
    initialisation hooks, and save/load behaviour all come from
    ``PreTrainedModel`` unchanged.

    NOTE(review): no ``forward`` is defined here, and this class rebinds
    the name ``DiT`` imported from ``models_hf`` at the top of the file.
    If the intent was to re-export or subclass that implementation, this
    stub replaces it instead. Confirm.
    """

    # Tells transformers which configuration class pairs with this model.
    config_class = DiTConfig
# Public API of this module: the model class and its configuration class.
__all__ = [
    "DiT",
    "DiTConfig",
]