ThomasTheMaker commited on
Commit
a620651
·
verified ·
1 Parent(s): 5b29f06

Upload folder using huggingface_hub

Browse files
fabric_state/checkpoint.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e42d749796c6550ffb318da21c493f94df7f0c48120ac9ecbbd0eb6402fc67ff
3
  size 135543171
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f78990840c8b2a26e89eea5f5414a84a9e8a0c76b9637d3cac17ec22e5486678
3
  size 135543171
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "transformers_version": "4.48.3",
3
+ "vocab_size": 50304
4
+ }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3d7be9a1e9b585a92821668324e20d977c23d51c04b2ade7610f764f62efe829
3
  size 45143592
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3084d44929c019203e308a3f500b8792ca69ff273c69edb7eb6a433268e540f9
3
  size 45143592
pico_decoder.py CHANGED
@@ -31,7 +31,8 @@ import torch
31
  import torch.nn as nn
32
  import torch.nn.functional as F
33
  from torch.nn.attention import SDPBackend, sdpa_kernel
34
- from transformers import PretrainedConfig, PreTrainedModel
 
35
  from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
36
 
37
  try:
@@ -134,7 +135,7 @@ class RoPE(nn.Module):
134
  Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
135
 
136
  Note other implementations will use cos and sin directly, but using the complex
137
- number representation is (probably?) more efficient:
138
 
139
  e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
140
  """
@@ -314,7 +315,7 @@ class Attention(nn.Module):
314
  queries.contiguous(),
315
  keys.contiguous(),
316
  values.contiguous(),
317
- attn_mask=mask.to(queries.dtype),
318
  enable_gqa=apply_gqa,
319
  )
320
 
@@ -556,9 +557,9 @@ class PicoDecoderHFConfig(PretrainedConfig):
556
  return cls.from_dict(asdict(model_config))
557
 
558
 
559
- class PicoDecoderHF(PreTrainedModel):
560
  """
561
- HuggingFace wrapper for the Pico model.
562
 
563
  Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
564
  wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
@@ -571,10 +572,18 @@ class PicoDecoderHF(PreTrainedModel):
571
 
572
  config_class = PicoDecoderHFConfig
573
  _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
 
574
 
575
  def __init__(self, config: PicoDecoderHFConfig):
576
  super().__init__(config)
577
  self.pico_decoder = PicoDecoder(config)
 
 
 
 
 
 
 
578
 
579
  def forward(
580
  self,
@@ -601,8 +610,262 @@ class PicoDecoderHF(PreTrainedModel):
601
  logits=logits,
602
  )
603
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
604
 
605
  # Register for auto classes
606
  PicoDecoderHFConfig.register_for_auto_class()
607
  PicoDecoderHF.register_for_auto_class("AutoModel")
608
  PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  import torch.nn as nn
32
  import torch.nn.functional as F
33
  from torch.nn.attention import SDPBackend, sdpa_kernel
34
+ from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
35
+ from transformers.generation import GenerationConfig
36
  from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
37
 
38
  try:
 
135
  Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
136
 
137
  Note other implementations will use cos and sin directly, but using the complex
138
+ number representation is (probably) more efficient:
139
 
140
  e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
141
  """
 
315
  queries.contiguous(),
316
  keys.contiguous(),
317
  values.contiguous(),
318
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
319
  enable_gqa=apply_gqa,
320
  )
321
 
 
557
  return cls.from_dict(asdict(model_config))
558
 
559
 
560
+ class PicoDecoderHF(PreTrainedModel, GenerationMixin):
561
  """
562
+ HuggingFace wrapper for the Pico model with generation support.
563
 
564
  Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
565
  wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
 
572
 
573
  config_class = PicoDecoderHFConfig
574
  _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
575
+ main_input_name = "input_ids"
576
 
577
  def __init__(self, config: PicoDecoderHFConfig):
578
  super().__init__(config)
579
  self.pico_decoder = PicoDecoder(config)
580
+ # Initialize generation config with defaults
581
+ self.generation_config = GenerationConfig()
582
+ # Set some reasonable defaults for the model
583
+ if hasattr(config, "max_position_embeddings"):
584
+ self.generation_config.max_length = config.max_position_embeddings
585
+ if hasattr(config, "vocab_size"):
586
+ self.generation_config.vocab_size = config.vocab_size
587
 
588
  def forward(
589
  self,
 
610
  logits=logits,
611
  )
612
 
613
+ def prepare_inputs_for_generation(
614
+ self,
615
+ input_ids: torch.LongTensor,
616
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
617
+ attention_mask: Optional[torch.LongTensor] = None,
618
+ **kwargs,
619
+ ) -> Dict[str, Any]:
620
+ """
621
+ Prepare inputs for generation.
622
+
623
+ Args:
624
+ input_ids: Input token IDs
625
+ past_key_values: Cached key-value pairs from previous forward passes
626
+ attention_mask: Attention mask for the input
627
+ **kwargs: Additional arguments
628
+
629
+ Returns:
630
+ Dictionary containing prepared inputs
631
+ """
632
+ # If we have past_key_values, we only need the last token
633
+ if past_key_values is not None:
634
+ input_ids = input_ids[:, -1:]
635
+
636
+ return {
637
+ "input_ids": input_ids,
638
+ "past_key_values": past_key_values,
639
+ "use_cache": True,
640
+ }
641
+
642
+ def get_input_embeddings(self):
643
+ """Get the input embeddings layer."""
644
+ return self.pico_decoder.embedding_proj
645
+
646
+ def set_input_embeddings(self, value):
647
+ """Set the input embeddings layer."""
648
+ self.pico_decoder.embedding_proj = value
649
+
650
+ def get_output_embeddings(self):
651
+ """Get the output embeddings layer."""
652
+ return self.pico_decoder.de_embedding_proj
653
+
654
+ def set_output_embeddings(self, value):
655
+ """Set the output embeddings layer."""
656
+ self.pico_decoder.de_embedding_proj = value
657
+
658
+ def get_lm_head(self):
659
+ """Get the language model head."""
660
+ return self.pico_decoder.de_embedding_proj
661
+
662
+ def can_generate(self) -> bool:
663
+ """Check if the model can generate text."""
664
+ return True
665
+
666
+ @property
667
+ def is_encoder_decoder(self) -> bool:
668
+ """Check if the model is an encoder-decoder model."""
669
+ return False
670
+
671
+ @property
672
+ def can_use_cache(self) -> bool:
673
+ """Check if the model can use KV cache."""
674
+ return True
675
+
676
+ def resize_token_embeddings(
677
+ self, new_num_tokens: Optional[int] = None
678
+ ) -> torch.nn.Embedding:
679
+ """Resize token embeddings."""
680
+ old_embeddings = self.get_input_embeddings()
681
+ if new_num_tokens is None:
682
+ new_num_tokens = old_embeddings.num_embeddings
683
+
684
+ new_embeddings = torch.nn.Embedding(
685
+ new_num_tokens, old_embeddings.embedding_dim
686
+ )
687
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
688
+ old_embeddings.weight.data
689
+ )
690
+
691
+ self.pico_decoder.embedding_proj = new_embeddings
692
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
693
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
694
+ )
695
+
696
+ return new_embeddings
697
+
698
 
699
  # Register for auto classes
700
  PicoDecoderHFConfig.register_for_auto_class()
701
  PicoDecoderHF.register_for_auto_class("AutoModel")
702
  PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
703
+
704
+
705
+ ########################################################
706
+ #
707
+ # New PicoDecoderForCausalLM class for generation support
708
+ #
709
+ ########################################################
710
+
711
+
712
+ class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
713
+ """
714
+ PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
715
+
716
+ This class is designed to work with existing checkpoints and provides full generation support.
717
+ It inherits from the right base classes that HuggingFace expects for text generation.
718
+ """
719
+
720
+ config_class = PicoDecoderHFConfig
721
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
722
+ main_input_name = "input_ids"
723
+
724
+ def __init__(self, config: PicoDecoderHFConfig):
725
+ super().__init__(config)
726
+ self.pico_decoder = PicoDecoder(config)
727
+ # Initialize generation config with defaults
728
+ self.generation_config = GenerationConfig()
729
+ # Set some reasonable defaults for the model
730
+ if hasattr(config, "max_position_embeddings"):
731
+ self.generation_config.max_length = config.max_position_embeddings
732
+ if hasattr(config, "vocab_size"):
733
+ self.generation_config.vocab_size = config.vocab_size
734
+
735
+ def forward(
736
+ self,
737
+ input_ids: torch.Tensor,
738
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
739
+ use_cache: bool = False,
740
+ **kwargs,
741
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
742
+ """Forward pass for text generation."""
743
+ logits, past_key_values = self.pico_decoder(
744
+ input_ids, past_key_values, use_cache
745
+ )
746
+ if use_cache:
747
+ return CausalLMOutputWithPast(
748
+ logits=logits,
749
+ past_key_values=past_key_values,
750
+ )
751
+ else:
752
+ return CausalLMOutput(
753
+ logits=logits,
754
+ )
755
+
756
+ def prepare_inputs_for_generation(
757
+ self,
758
+ input_ids: torch.LongTensor,
759
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
760
+ attention_mask: Optional[torch.LongTensor] = None,
761
+ **kwargs,
762
+ ) -> Dict[str, Any]:
763
+ """Prepare inputs for generation."""
764
+ # If we have past_key_values, we only need the last token
765
+ if past_key_values is not None:
766
+ input_ids = input_ids[:, -1:]
767
+
768
+ return {
769
+ "input_ids": input_ids,
770
+ "past_key_values": past_key_values,
771
+ "use_cache": True,
772
+ }
773
+
774
+ def get_input_embeddings(self):
775
+ """Get the input embeddings layer."""
776
+ return self.pico_decoder.embedding_proj
777
+
778
+ def set_input_embeddings(self, value):
779
+ """Set the input embeddings layer."""
780
+ self.pico_decoder.embedding_proj = value
781
+
782
+ def get_output_embeddings(self):
783
+ """Get the output embeddings layer."""
784
+ return self.pico_decoder.de_embedding_proj
785
+
786
+ def set_output_embeddings(self, value):
787
+ """Set the output embeddings layer."""
788
+ self.pico_decoder.de_embedding_proj = value
789
+
790
+ def get_lm_head(self):
791
+ """Get the language model head."""
792
+ return self.pico_decoder.de_embedding_proj
793
+
794
+ def can_generate(self) -> bool:
795
+ """Check if the model can generate text."""
796
+ return True
797
+
798
+ @property
799
+ def is_encoder_decoder(self) -> bool:
800
+ """Check if the model is an encoder-decoder model."""
801
+ return False
802
+
803
+ @property
804
+ def can_use_cache(self) -> bool:
805
+ """Check if the model can use KV cache."""
806
+ return True
807
+
808
+ def resize_token_embeddings(
809
+ self, new_num_tokens: Optional[int] = None
810
+ ) -> torch.nn.Embedding:
811
+ """Resize token embeddings."""
812
+ old_embeddings = self.get_input_embeddings()
813
+ if new_num_tokens is None:
814
+ new_num_tokens = old_embeddings.num_embeddings
815
+
816
+ new_embeddings = torch.nn.Embedding(
817
+ new_num_tokens, old_embeddings.embedding_dim
818
+ )
819
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
820
+ old_embeddings.weight.data
821
+ )
822
+
823
+ self.pico_decoder.embedding_proj = new_embeddings
824
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
825
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
826
+ )
827
+
828
+ return new_embeddings
829
+
830
+ @classmethod
831
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
832
+ """
833
+ Load a pretrained model from a checkpoint.
834
+
835
+ This method handles loading from both the old PicoDecoderHF format and the new format.
836
+ """
837
+ # First try to load with the new class
838
+ try:
839
+ return super().from_pretrained(
840
+ pretrained_model_name_or_path, *model_args, **kwargs
841
+ )
842
+ except Exception as e:
843
+ print(f"Failed to load with new class: {e}")
844
+ print("Attempting to load with legacy class and convert...")
845
+
846
+ # Try to load with the old class and convert
847
+ try:
848
+ from transformers import AutoModel
849
+
850
+ old_model = AutoModel.from_pretrained(
851
+ pretrained_model_name_or_path,
852
+ trust_remote_code=True,
853
+ *model_args,
854
+ **kwargs,
855
+ )
856
+
857
+ # Create new model instance
858
+ new_model = cls(old_model.config)
859
+
860
+ # Copy state dict
861
+ new_model.load_state_dict(old_model.state_dict(), strict=False)
862
+
863
+ return new_model
864
+
865
+ except Exception as e2:
866
+ print(f"Failed to convert from legacy format: {e2}")
867
+ raise e
868
+
869
+
870
+ # Register the new class
871
+ PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")