Upload folder using huggingface_hub
Browse files- fabric_state/checkpoint.pt +1 -1
- generation_config.json +4 -0
- model.safetensors +1 -1
- pico_decoder.py +268 -5
fabric_state/checkpoint.pt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 135543171
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f78990840c8b2a26e89eea5f5414a84a9e8a0c76b9637d3cac17ec22e5486678
|
3 |
size 135543171
|
generation_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"transformers_version": "4.48.3",
|
3 |
+
"vocab_size": 50304
|
4 |
+
}
|
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 45143592
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3084d44929c019203e308a3f500b8792ca69ff273c69edb7eb6a433268e540f9
|
3 |
size 45143592
|
pico_decoder.py
CHANGED
@@ -31,7 +31,8 @@ import torch
|
|
31 |
import torch.nn as nn
|
32 |
import torch.nn.functional as F
|
33 |
from torch.nn.attention import SDPBackend, sdpa_kernel
|
34 |
-
from transformers import PretrainedConfig, PreTrainedModel
|
|
|
35 |
from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
|
36 |
|
37 |
try:
|
@@ -134,7 +135,7 @@ class RoPE(nn.Module):
|
|
134 |
Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
|
135 |
|
136 |
Note other implementations will use cos and sin directly, but using the complex
|
137 |
-
number representation is (probably
|
138 |
|
139 |
e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
|
140 |
"""
|
@@ -314,7 +315,7 @@ class Attention(nn.Module):
|
|
314 |
queries.contiguous(),
|
315 |
keys.contiguous(),
|
316 |
values.contiguous(),
|
317 |
-
attn_mask=mask.to(queries.dtype),
|
318 |
enable_gqa=apply_gqa,
|
319 |
)
|
320 |
|
@@ -556,9 +557,9 @@ class PicoDecoderHFConfig(PretrainedConfig):
|
|
556 |
return cls.from_dict(asdict(model_config))
|
557 |
|
558 |
|
559 |
-
class PicoDecoderHF(PreTrainedModel):
|
560 |
"""
|
561 |
-
HuggingFace wrapper for the Pico model.
|
562 |
|
563 |
Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
|
564 |
wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
|
@@ -571,10 +572,18 @@ class PicoDecoderHF(PreTrainedModel):
|
|
571 |
|
572 |
config_class = PicoDecoderHFConfig
|
573 |
_no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
|
|
|
574 |
|
575 |
def __init__(self, config: PicoDecoderHFConfig):
|
576 |
super().__init__(config)
|
577 |
self.pico_decoder = PicoDecoder(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
578 |
|
579 |
def forward(
|
580 |
self,
|
@@ -601,8 +610,262 @@ class PicoDecoderHF(PreTrainedModel):
|
|
601 |
logits=logits,
|
602 |
)
|
603 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
604 |
|
605 |
# Register for auto classes
|
606 |
PicoDecoderHFConfig.register_for_auto_class()
|
607 |
PicoDecoderHF.register_for_auto_class("AutoModel")
|
608 |
PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
import torch.nn as nn
|
32 |
import torch.nn.functional as F
|
33 |
from torch.nn.attention import SDPBackend, sdpa_kernel
|
34 |
+
from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
|
35 |
+
from transformers.generation import GenerationConfig
|
36 |
from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
|
37 |
|
38 |
try:
|
|
|
135 |
Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
|
136 |
|
137 |
Note other implementations will use cos and sin directly, but using the complex
|
138 |
+
number representation is (probably) more efficient:
|
139 |
|
140 |
e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
|
141 |
"""
|
|
|
315 |
queries.contiguous(),
|
316 |
keys.contiguous(),
|
317 |
values.contiguous(),
|
318 |
+
attn_mask=mask.to(queries.dtype) if mask is not None else None,
|
319 |
enable_gqa=apply_gqa,
|
320 |
)
|
321 |
|
|
|
557 |
return cls.from_dict(asdict(model_config))
|
558 |
|
559 |
|
560 |
+
class PicoDecoderHF(PreTrainedModel, GenerationMixin):
|
561 |
"""
|
562 |
+
HuggingFace wrapper for the Pico model with generation support.
|
563 |
|
564 |
Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
|
565 |
wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
|
|
|
572 |
|
573 |
config_class = PicoDecoderHFConfig
|
574 |
_no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
|
575 |
+
main_input_name = "input_ids"
|
576 |
|
577 |
def __init__(self, config: PicoDecoderHFConfig):
|
578 |
super().__init__(config)
|
579 |
self.pico_decoder = PicoDecoder(config)
|
580 |
+
# Initialize generation config with defaults
|
581 |
+
self.generation_config = GenerationConfig()
|
582 |
+
# Set some reasonable defaults for the model
|
583 |
+
if hasattr(config, "max_position_embeddings"):
|
584 |
+
self.generation_config.max_length = config.max_position_embeddings
|
585 |
+
if hasattr(config, "vocab_size"):
|
586 |
+
self.generation_config.vocab_size = config.vocab_size
|
587 |
|
588 |
def forward(
|
589 |
self,
|
|
|
610 |
logits=logits,
|
611 |
)
|
612 |
|
613 |
+
def prepare_inputs_for_generation(
|
614 |
+
self,
|
615 |
+
input_ids: torch.LongTensor,
|
616 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
617 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
618 |
+
**kwargs,
|
619 |
+
) -> Dict[str, Any]:
|
620 |
+
"""
|
621 |
+
Prepare inputs for generation.
|
622 |
+
|
623 |
+
Args:
|
624 |
+
input_ids: Input token IDs
|
625 |
+
past_key_values: Cached key-value pairs from previous forward passes
|
626 |
+
attention_mask: Attention mask for the input
|
627 |
+
**kwargs: Additional arguments
|
628 |
+
|
629 |
+
Returns:
|
630 |
+
Dictionary containing prepared inputs
|
631 |
+
"""
|
632 |
+
# If we have past_key_values, we only need the last token
|
633 |
+
if past_key_values is not None:
|
634 |
+
input_ids = input_ids[:, -1:]
|
635 |
+
|
636 |
+
return {
|
637 |
+
"input_ids": input_ids,
|
638 |
+
"past_key_values": past_key_values,
|
639 |
+
"use_cache": True,
|
640 |
+
}
|
641 |
+
|
642 |
+
def get_input_embeddings(self):
|
643 |
+
"""Get the input embeddings layer."""
|
644 |
+
return self.pico_decoder.embedding_proj
|
645 |
+
|
646 |
+
def set_input_embeddings(self, value):
|
647 |
+
"""Set the input embeddings layer."""
|
648 |
+
self.pico_decoder.embedding_proj = value
|
649 |
+
|
650 |
+
def get_output_embeddings(self):
|
651 |
+
"""Get the output embeddings layer."""
|
652 |
+
return self.pico_decoder.de_embedding_proj
|
653 |
+
|
654 |
+
def set_output_embeddings(self, value):
|
655 |
+
"""Set the output embeddings layer."""
|
656 |
+
self.pico_decoder.de_embedding_proj = value
|
657 |
+
|
658 |
+
def get_lm_head(self):
|
659 |
+
"""Get the language model head."""
|
660 |
+
return self.pico_decoder.de_embedding_proj
|
661 |
+
|
662 |
+
def can_generate(self) -> bool:
|
663 |
+
"""Check if the model can generate text."""
|
664 |
+
return True
|
665 |
+
|
666 |
+
@property
|
667 |
+
def is_encoder_decoder(self) -> bool:
|
668 |
+
"""Check if the model is an encoder-decoder model."""
|
669 |
+
return False
|
670 |
+
|
671 |
+
@property
|
672 |
+
def can_use_cache(self) -> bool:
|
673 |
+
"""Check if the model can use KV cache."""
|
674 |
+
return True
|
675 |
+
|
676 |
+
def resize_token_embeddings(
|
677 |
+
self, new_num_tokens: Optional[int] = None
|
678 |
+
) -> torch.nn.Embedding:
|
679 |
+
"""Resize token embeddings."""
|
680 |
+
old_embeddings = self.get_input_embeddings()
|
681 |
+
if new_num_tokens is None:
|
682 |
+
new_num_tokens = old_embeddings.num_embeddings
|
683 |
+
|
684 |
+
new_embeddings = torch.nn.Embedding(
|
685 |
+
new_num_tokens, old_embeddings.embedding_dim
|
686 |
+
)
|
687 |
+
new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
|
688 |
+
old_embeddings.weight.data
|
689 |
+
)
|
690 |
+
|
691 |
+
self.pico_decoder.embedding_proj = new_embeddings
|
692 |
+
self.pico_decoder.de_embedding_proj = torch.nn.Linear(
|
693 |
+
old_embeddings.embedding_dim, new_num_tokens, bias=False
|
694 |
+
)
|
695 |
+
|
696 |
+
return new_embeddings
|
697 |
+
|
698 |
|
699 |
# Register for auto classes
|
700 |
PicoDecoderHFConfig.register_for_auto_class()
|
701 |
PicoDecoderHF.register_for_auto_class("AutoModel")
|
702 |
PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
|
703 |
+
|
704 |
+
|
705 |
+
########################################################
|
706 |
+
#
|
707 |
+
# New PicoDecoderForCausalLM class for generation support
|
708 |
+
#
|
709 |
+
########################################################
|
710 |
+
|
711 |
+
|
712 |
+
class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
|
713 |
+
"""
|
714 |
+
PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
|
715 |
+
|
716 |
+
This class is designed to work with existing checkpoints and provides full generation support.
|
717 |
+
It inherits from the right base classes that HuggingFace expects for text generation.
|
718 |
+
"""
|
719 |
+
|
720 |
+
config_class = PicoDecoderHFConfig
|
721 |
+
_no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
|
722 |
+
main_input_name = "input_ids"
|
723 |
+
|
724 |
+
def __init__(self, config: PicoDecoderHFConfig):
|
725 |
+
super().__init__(config)
|
726 |
+
self.pico_decoder = PicoDecoder(config)
|
727 |
+
# Initialize generation config with defaults
|
728 |
+
self.generation_config = GenerationConfig()
|
729 |
+
# Set some reasonable defaults for the model
|
730 |
+
if hasattr(config, "max_position_embeddings"):
|
731 |
+
self.generation_config.max_length = config.max_position_embeddings
|
732 |
+
if hasattr(config, "vocab_size"):
|
733 |
+
self.generation_config.vocab_size = config.vocab_size
|
734 |
+
|
735 |
+
def forward(
|
736 |
+
self,
|
737 |
+
input_ids: torch.Tensor,
|
738 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
739 |
+
use_cache: bool = False,
|
740 |
+
**kwargs,
|
741 |
+
) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
|
742 |
+
"""Forward pass for text generation."""
|
743 |
+
logits, past_key_values = self.pico_decoder(
|
744 |
+
input_ids, past_key_values, use_cache
|
745 |
+
)
|
746 |
+
if use_cache:
|
747 |
+
return CausalLMOutputWithPast(
|
748 |
+
logits=logits,
|
749 |
+
past_key_values=past_key_values,
|
750 |
+
)
|
751 |
+
else:
|
752 |
+
return CausalLMOutput(
|
753 |
+
logits=logits,
|
754 |
+
)
|
755 |
+
|
756 |
+
def prepare_inputs_for_generation(
|
757 |
+
self,
|
758 |
+
input_ids: torch.LongTensor,
|
759 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
760 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
761 |
+
**kwargs,
|
762 |
+
) -> Dict[str, Any]:
|
763 |
+
"""Prepare inputs for generation."""
|
764 |
+
# If we have past_key_values, we only need the last token
|
765 |
+
if past_key_values is not None:
|
766 |
+
input_ids = input_ids[:, -1:]
|
767 |
+
|
768 |
+
return {
|
769 |
+
"input_ids": input_ids,
|
770 |
+
"past_key_values": past_key_values,
|
771 |
+
"use_cache": True,
|
772 |
+
}
|
773 |
+
|
774 |
+
def get_input_embeddings(self):
|
775 |
+
"""Get the input embeddings layer."""
|
776 |
+
return self.pico_decoder.embedding_proj
|
777 |
+
|
778 |
+
def set_input_embeddings(self, value):
|
779 |
+
"""Set the input embeddings layer."""
|
780 |
+
self.pico_decoder.embedding_proj = value
|
781 |
+
|
782 |
+
def get_output_embeddings(self):
|
783 |
+
"""Get the output embeddings layer."""
|
784 |
+
return self.pico_decoder.de_embedding_proj
|
785 |
+
|
786 |
+
def set_output_embeddings(self, value):
|
787 |
+
"""Set the output embeddings layer."""
|
788 |
+
self.pico_decoder.de_embedding_proj = value
|
789 |
+
|
790 |
+
def get_lm_head(self):
|
791 |
+
"""Get the language model head."""
|
792 |
+
return self.pico_decoder.de_embedding_proj
|
793 |
+
|
794 |
+
def can_generate(self) -> bool:
|
795 |
+
"""Check if the model can generate text."""
|
796 |
+
return True
|
797 |
+
|
798 |
+
@property
|
799 |
+
def is_encoder_decoder(self) -> bool:
|
800 |
+
"""Check if the model is an encoder-decoder model."""
|
801 |
+
return False
|
802 |
+
|
803 |
+
@property
|
804 |
+
def can_use_cache(self) -> bool:
|
805 |
+
"""Check if the model can use KV cache."""
|
806 |
+
return True
|
807 |
+
|
808 |
+
def resize_token_embeddings(
|
809 |
+
self, new_num_tokens: Optional[int] = None
|
810 |
+
) -> torch.nn.Embedding:
|
811 |
+
"""Resize token embeddings."""
|
812 |
+
old_embeddings = self.get_input_embeddings()
|
813 |
+
if new_num_tokens is None:
|
814 |
+
new_num_tokens = old_embeddings.num_embeddings
|
815 |
+
|
816 |
+
new_embeddings = torch.nn.Embedding(
|
817 |
+
new_num_tokens, old_embeddings.embedding_dim
|
818 |
+
)
|
819 |
+
new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
|
820 |
+
old_embeddings.weight.data
|
821 |
+
)
|
822 |
+
|
823 |
+
self.pico_decoder.embedding_proj = new_embeddings
|
824 |
+
self.pico_decoder.de_embedding_proj = torch.nn.Linear(
|
825 |
+
old_embeddings.embedding_dim, new_num_tokens, bias=False
|
826 |
+
)
|
827 |
+
|
828 |
+
return new_embeddings
|
829 |
+
|
830 |
+
@classmethod
|
831 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
832 |
+
"""
|
833 |
+
Load a pretrained model from a checkpoint.
|
834 |
+
|
835 |
+
This method handles loading from both the old PicoDecoderHF format and the new format.
|
836 |
+
"""
|
837 |
+
# First try to load with the new class
|
838 |
+
try:
|
839 |
+
return super().from_pretrained(
|
840 |
+
pretrained_model_name_or_path, *model_args, **kwargs
|
841 |
+
)
|
842 |
+
except Exception as e:
|
843 |
+
print(f"Failed to load with new class: {e}")
|
844 |
+
print("Attempting to load with legacy class and convert...")
|
845 |
+
|
846 |
+
# Try to load with the old class and convert
|
847 |
+
try:
|
848 |
+
from transformers import AutoModel
|
849 |
+
|
850 |
+
old_model = AutoModel.from_pretrained(
|
851 |
+
pretrained_model_name_or_path,
|
852 |
+
trust_remote_code=True,
|
853 |
+
*model_args,
|
854 |
+
**kwargs,
|
855 |
+
)
|
856 |
+
|
857 |
+
# Create new model instance
|
858 |
+
new_model = cls(old_model.config)
|
859 |
+
|
860 |
+
# Copy state dict
|
861 |
+
new_model.load_state_dict(old_model.state_dict(), strict=False)
|
862 |
+
|
863 |
+
return new_model
|
864 |
+
|
865 |
+
except Exception as e2:
|
866 |
+
print(f"Failed to convert from legacy format: {e2}")
|
867 |
+
raise e
|
868 |
+
|
869 |
+
|
870 |
+
# Register the new class
|
871 |
+
PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|