# NOTE: "Spaces: Running / Running" status-banner text from the original
# web capture was a scrape artifact, not part of this module.
""" | |
Configuration for model management. | |
This module provides configuration for loading and managing models | |
from Hugging Face's model hub. | |
""" | |
import os
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional
# Closed sets of legal values for ModelConfig fields.
ModelType = Literal["text-generation", "text-embedding", "vision", "multimodal"]
DeviceType = Literal["auto", "cpu", "cuda"]


@dataclass
class ModelConfig:
    """Configuration for a single model.

    BUG FIX: the class was fully annotated but never decorated with
    ``@dataclass``, so ``ModelConfig(model_id=..., ...)`` — the exact call
    made for every entry in DEFAULT_MODELS — raised TypeError because a
    plain class gets no generated ``__init__``.  The decorator restores the
    keyword constructor the rest of this module relies on.
    """

    model_id: str              # Hugging Face hub repo id, e.g. "org/name"
    model_path: str            # local directory where the weights are stored
    model_type: ModelType
    device: DeviceType = "auto"
    quantize: bool = True
    use_safetensors: bool = True
    # SECURITY: trust_remote_code=True lets the hub repo execute arbitrary
    # Python on load.  Kept for backward compatibility, but consider
    # flipping the default to False for untrusted model ids.
    trust_remote_code: bool = True
    description: str = ""
    size_gb: float = 0.0       # Approximate size in GB
    recommended: bool = False
# Available models with their configurations.
# Registry keyed by the short name callers pass to get_model_config() /
# get_model_path().  Grouped by approximate disk footprint (size_gb).
DEFAULT_MODELS = {
    # Lightweight models (under 2GB)
    "tiny-llama": ModelConfig(
        model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        model_path="./models/tiny-llama-1.1b-chat",
        model_type="text-generation",
        quantize=True,
        description="Very small and fast model, good for quick testing",
        size_gb=1.1,
        recommended=True
    ),
    # Medium models (2-10GB)
    "mistral-7b": ModelConfig(
        # Pre-quantized GPTQ repo, hence the 4GB footprint for a 7B model.
        model_id="TheBloke/Mistral-7B-Instruct-v0.1-GPTQ",
        model_path="./models/mistral-7b-instruct-gptq",
        model_type="text-generation",
        quantize=True,
        description="Good balance of performance and resource usage",
        size_gb=4.0,
        recommended=True
    ),
    "llama2-7b": ModelConfig(
        # NOTE(review): gated repo — downloads require accepting Meta's
        # license and an authenticated HF token; verify before relying on it.
        model_id="meta-llama/Llama-2-7b-chat-hf",
        model_path="./models/llama2-7b-chat",
        model_type="text-generation",
        description="High quality 7B parameter model from Meta",
        size_gb=13.0
    ),
    # Embedding models
    "all-mpnet-base-v2": ModelConfig(
        model_id="sentence-transformers/all-mpnet-base-v2",
        model_path="./models/all-mpnet-base-v2",
        model_type="text-embedding",
        description="General purpose sentence transformer, good balance of speed and quality",
        size_gb=0.4,
        recommended=True
    ),
    "bge-small-en": ModelConfig(
        model_id="BAAI/bge-small-en-v1.5",
        model_path="./models/bge-small-en",
        model_type="text-embedding",
        description="Small but powerful embedding model",
        size_gb=0.13
    ),
    # Larger models (10GB+)
    "mixtral-8x7b": ModelConfig(
        model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        model_path="./models/mixtral-8x7b-instruct",
        model_type="text-generation",
        description="Very powerful 8x7B MoE model, requires significant resources",
        size_gb=85.0
    ),
    "llama2-13b": ModelConfig(
        # NOTE(review): gated repo, same access caveat as llama2-7b.
        model_id="meta-llama/Llama-2-13b-chat-hf",
        model_path="./models/llama2-13b-chat",
        model_type="text-generation",
        description="High quality 13B parameter model, better reasoning capabilities",
        size_gb=26.0
    )
}
def get_model_config(model_name: str) -> Optional[ModelConfig]:
    """Look up the configuration registered under *model_name*.

    Returns None when the name is not in the default registry.
    """
    try:
        return DEFAULT_MODELS[model_name]
    except KeyError:
        return None
def list_available_models() -> List[str]:
    """Return the short name of every model in the default registry."""
    return [name for name in DEFAULT_MODELS]
def get_model_path(model_name: str) -> str:
    """Return the local directory for *model_name*, downloading it on first use.

    The presence of ``config.json`` in the model directory is used as the
    "already downloaded" marker; when missing, the full snapshot is fetched
    from the Hugging Face hub into ``config.model_path``.

    Args:
        model_name: A key of ``DEFAULT_MODELS``.

    Returns:
        The local filesystem path holding the model files.

    Raises:
        ValueError: If *model_name* is not a known model.
    """
    config = get_model_config(model_name)
    if config is None:
        raise ValueError(f"Unknown model: {model_name}")
    # Create model directory if it doesn't exist.
    os.makedirs(config.model_path, exist_ok=True)
    # If model files don't exist, download them.
    if not os.path.exists(os.path.join(config.model_path, "config.json")):
        # Imported lazily so the module is importable without huggingface_hub.
        from huggingface_hub import snapshot_download

        # FIX: dropped local_dir_use_symlinks — deprecated in huggingface_hub
        # (ignored with a warning since 0.23; local_dir downloads no longer
        # use symlinks).  TF/ONNX/msgpack weight formats are skipped.
        snapshot_download(
            repo_id=config.model_id,
            local_dir=config.model_path,
            ignore_patterns=["*.h5", "*.ot", "*.msgpack"],
        )
    return config.model_path