
Mistral-SUPRA

This model was initialized from the weights of the Mistral-7B transformer model and up-trained into a linear RNN.

This model accompanies our paper Linearizing Large Language Models. The paper describes how we convert a softmax transformer into a linear transformer that can run either as a transformer or as a recurrent model at inference time. Our linear attention code is available at https://github.com/TRI-ML/linear_open_lm/
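Generic causal linear attention shows why one set of weights can run both as a transformer and as a recurrent model. The sketch below is a minimal illustration of that idea, not the exact formulation used in Mistral-SUPRA (see the paper and the linear_open_lm repository for that). It makes two assumptions: a simple positive feature map φ(x) = elu(x) + 1, and a normalizing denominator. The parallel form computes every position at once under a causal mask. The recurrent form processes one token at a time and carries only a fixed-size state, and it produces the same outputs.

```python
import numpy as np

def feature_map(x: np.ndarray) -> np.ndarray:
    # Assumed positive feature map, elu(x) + 1 (illustrative only).
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention_parallel(q, k, v):
    """Causal linear attention over the whole sequence at once (transformer-style)."""
    fq, fk = feature_map(q), feature_map(k)            # (T, d_k)
    scores = np.tril(fq @ fk.T)                        # (T, T), causal mask, no softmax
    return (scores @ v) / scores.sum(axis=1, keepdims=True)

def linear_attention_recurrent(q, k, v):
    """Same computation token by token, carrying a fixed-size state (RNN-style)."""
    d_k, d_v = q.shape[1], v.shape[1]
    state = np.zeros((d_k, d_v))   # running sum of phi(k_i) v_i^T
    normalizer = np.zeros(d_k)     # running sum of phi(k_i)
    outputs = []
    for t in range(q.shape[0]):
        fq_t, fk_t = feature_map(q[t]), feature_map(k[t])
        state += np.outer(fk_t, v[t])
        normalizer += fk_t
        outputs.append((fq_t @ state) / (fq_t @ normalizer))
    return np.stack(outputs)

# Both forms agree on random inputs.
rng = np.random.default_rng(0)
T, d_k, d_v = 6, 4, 3
q, k, v = (rng.standard_normal((T, d)) for d in (d_k, d_k, d_v))
assert np.allclose(linear_attention_parallel(q, k, v), linear_attention_recurrent(q, k, v))
```

In the recurrent form, the state has shape (d_k, d_v) however long the sequence grows. This is what allows constant-memory generation in recurrent mode.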

We uptrain Mistral-7B on 100B tokens of RefinedWeb.

Model Details

Parameters  Hidden Size  Layers  Vocab Size  Sequence Length
7B          4096         32      32000       2048

Training Details

  • Mistral-SUPRA was trained using AWS SageMaker on 128 H100 80GB GPUs.
  • Training on 100B tokens finished in 1.5 days.

Hyperparameter     Value
Precision          bfloat16
Optimizer          AdamW
Learning rate      3e-5
LR cooldown end    1e-5
Warmup steps       1000
Batch size         2M
QK norm            False
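The sketch below is a rough consistency check on the training numbers above. It derives the number of optimizer steps and the implied per-GPU throughput. It makes two assumptions. First, the "2M" batch size is taken to be 2,000,000 tokens. Second, the learning rate is assumed to rise linearly to 3e-5 over the warmup steps and then follow a cosine decay to the 1e-5 cooldown end. The schedule shape is an illustrative assumption and is not taken from the actual training configuration. `learning_rate` is a hypothetical helper.

```python
import math

TOTAL_TOKENS = 100e9            # 100B tokens of RefinedWeb
BATCH_TOKENS = 2e6              # assumption: "2M" batch size means 2,000,000 tokens
NUM_GPUS = 128
WALL_CLOCK_S = 1.5 * 24 * 3600  # 1.5 days in seconds

total_steps = int(TOTAL_TOKENS / BATCH_TOKENS)
tokens_per_s_per_gpu = TOTAL_TOKENS / WALL_CLOCK_S / NUM_GPUS

def learning_rate(step, peak=3e-5, end=1e-5, warmup=1000, total=total_steps):
    """Hypothetical schedule: linear warmup, then cosine decay from peak to end."""
    if step < warmup:
        return peak * (step + 1) / warmup
    progress = (step - warmup) / max(1, total - warmup)
    return end + 0.5 * (peak - end) * (1 + math.cos(math.pi * progress))

print(total_steps)                  # 50000
print(round(tokens_per_s_per_gpu))  # 6028
```

Under these assumptions, the run takes about 50,000 optimizer steps at roughly 6,000 tokens per second per GPU.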

Usage

This model was trained using OpenLM. The weights have been converted to be compatible with Hugging Face Transformers.

To use the model, first install our fork of OpenLM with pip:

pip install git+https://github.com/tri-ml/linear_open_lm.git

Then import the OpenLM classes:

from open_lm.open_lm_hf import *

The model can then be loaded normally using AutoTokenizer and AutoModelForCausalLM as follows:

from open_lm.open_lm_hf import *
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("tri-ml/mistral-supra")
model = AutoModelForCausalLM.from_pretrained("tri-ml/mistral-supra")

# Tokenize a prompt and sample a continuation
inputs = tokenizer(["Machine learning is"], return_tensors="pt")
gen_kwargs = {"max_new_tokens": 50, "top_p": 0.8, "temperature": 0.8, "do_sample": True, "repetition_penalty": 1.1}
output = model.generate(inputs['input_ids'], **gen_kwargs)
output = tokenizer.decode(output[0].tolist(), skip_special_tokens=True)
print(output)
# Example output (do_sample=True, so your text will differ):
# Machine learning is a branch of artificial intelligence (AI) that enables computers to learn from experience without being explicitly programmed. Machine learning is used in a wide range of applications, including spam filtering, image recognition, speech recognition, and computer-based medical diagnosis
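The gen_kwargs in the example above combine three settings before sampling: a repetition penalty, temperature scaling, and nucleus (top-p) filtering. The pure-Python sketch below shows conceptually what these settings do to a vector of next-token logits. It is a simplified illustration written for this card, not the Transformers implementation. `sample_next_token` is a hypothetical helper.

```python
import math
import random

def sample_next_token(logits, previous_tokens, temperature=0.8, top_p=0.8,
                      repetition_penalty=1.1, rng=random):
    """Illustrative next-token sampler mirroring the gen_kwargs used above."""
    logits = list(logits)
    # Repetition penalty: make already-seen tokens less likely.
    for tok in set(previous_tokens):
        logits[tok] = logits[tok] / repetition_penalty if logits[tok] > 0 else logits[tok] * repetition_penalty
    # Temperature below 1 sharpens the distribution.
    scaled = [x / temperature for x in logits]
    # Softmax (subtract the max for numerical stability).
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Nucleus filtering: keep the most likely tokens until their mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    # Renormalize over the kept tokens and sample one of them.
    kept_mass = sum(probs[i] for i in kept)
    return rng.choices(kept, weights=[probs[i] / kept_mass for i in kept], k=1)[0]

# Token 0 dominates after temperature scaling, so it alone fills the 0.8 nucleus.
print(sample_next_token([4.0, 1.0, 0.5, -2.0], previous_tokens=[], rng=random.Random(0)))  # 0
```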

Mistral-SUPRA can run in either parallel mode or recurrent mode. If use_cache is set to False in model.generate(...), the model uses parallel mode; otherwise, it uses recurrent mode. Recurrent mode relies on xformers and requires both the model and the inputs to be on the GPU.

# Recurrent mode
output = model.to('cuda').generate(inputs['input_ids'].to('cuda'), use_cache=True, **gen_kwargs)

# Parallel mode
output = model.to('cuda').generate(inputs['input_ids'].to('cuda'), use_cache=False, **gen_kwargs)
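In practice, the difference between the two modes is what the model keeps between decoding steps. A softmax transformer's KV cache grows with every token. A linear-attention recurrent state stays the same size. The back-of-the-envelope sketch below takes the hidden size (4096) and layer count (32) from Model Details. It also makes several simplifying assumptions, so the absolute numbers are illustrative only:

  • 32 heads of dimension 128
  • full multi-head keys and values, with no grouped-query attention
  • one head_dim × head_dim state per head
  • bfloat16 storage

```python
HIDDEN = 4096                # from Model Details
LAYERS = 32                  # from Model Details
HEADS = 32                   # assumption
HEAD_DIM = HIDDEN // HEADS   # 128 under that assumption
BYTES = 2                    # bfloat16

def kv_cache_bytes(seq_len: int) -> int:
    """Softmax attention: one key and one value vector per token, per layer."""
    return 2 * LAYERS * seq_len * HIDDEN * BYTES

def recurrent_state_bytes() -> int:
    """Assumed linear-attention state: HEAD_DIM x HEAD_DIM per head, per layer, for any length."""
    return LAYERS * HEADS * HEAD_DIM * HEAD_DIM * BYTES

for n in (64, 2048, 32768):
    print(f"{n:>6} tokens: KV cache {kv_cache_bytes(n) / 2**20:.0f} MiB, "
          f"recurrent state {recurrent_state_bytes() / 2**20:.0f} MiB")
```

Under these assumptions, the KV cache matches the fixed 32 MiB state at 64 tokens. At the 2048-token training length it is 32 times larger.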

Performance Evaluation

Our evaluations were run with EleutherAI's LM Evaluation Harness.

Below we report the performance of Mistral-SUPRA compared to other similarly sized models.

Model           HellaSwag  PIQA  Winogrande  ARC-E  ARC-C  MMLU (5-shot)
Llama2-7B       76.0       79.1  69.1        76.3   46.3   45.9
Gemma-7B        80.7       81.9  73.7        81.1   53.2   62.9
Mistral-7B      81.0       82.1  74.0        80.9   53.8   62.4
RWKV5-1.7T-7B   73.0       78.6  72.9        75.8   45.6   34.9
Mamba-7B        77.9       81.0  71.8        77.5   46.7   33.3
Mistral-SUPRA   77.1       80.4  70.3        75.9   45.8   34.2
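The short script below makes the comparison with the source model explicit. It computes the per-task difference between Mistral-SUPRA and Mistral-7B using only the numbers in the table above. A negative value means Mistral-SUPRA scores lower.

```python
TASKS = ["HellaSwag", "PIQA", "Winogrande", "ARC-E", "ARC-C", "MMLU (5-shot)"]
SCORES = {
    "Mistral-7B":    [81.0, 82.1, 74.0, 80.9, 53.8, 62.4],
    "Mistral-SUPRA": [77.1, 80.4, 70.3, 75.9, 45.8, 34.2],
}

# Per-task difference, rounded to the table's precision.
deltas = {
    task: round(supra - base, 1)
    for task, base, supra in zip(TASKS, SCORES["Mistral-7B"], SCORES["Mistral-SUPRA"])
}
for task, delta in deltas.items():
    print(f"{task:14s} {delta:+.1f}")
```

The gap is between 1.7 and 8.0 points on the five non-MMLU tasks and 28.2 points on MMLU.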

How to Cite

If you use this model, please cite our paper, Linearizing Large Language Models:

@article{Mercat2024Linearizing,
  title={Linearizing Large Language Models},
  author={Jean Mercat and Igor Vasiljevic and Sedrick Keh and Kushal Arora and Achal Dave and Adrien Gaidon and Thomas Kollar},
  year={2024},
  journal={arXiv preprint arXiv:2405.06640},
}

Citations

OpenLM

@misc{open_lm,
  author = {Gururangan, Suchin and Wortsman, Mitchell and Gadre, Samir Yitzhak and Dave, Achal and Kilian, Maciej and Shi, Weijia and Mercat, Jean and Smyrnis, Georgios and Ilharco, Gabriel and Jordan, Matt and Heckel, Reinhard and Dimakis, Alex and Farhadi, Ali and Shankar, Vaishaal and Schmidt, Ludwig},
  title = {{open_lm}:  a minimal but performative language modeling (LM) repository},
  year = {2023},
  note = {GitHub repository},
  url = {https://github.com/mlfoundations/open_lm/}
}