# agentic-browser/config/model_config.py
"""
Configuration for model management.
This module provides configuration for loading and managing models
from Hugging Face's model hub.
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Literal
import os
# Closed sets of allowed values, enforced statically via typing.Literal.
ModelType = Literal["text-generation", "text-embedding", "vision", "multimodal"]
DeviceType = Literal["auto", "cpu", "cuda"]


@dataclass
class ModelConfig:
    """Configuration for a single model.

    Describes where a model lives on the Hugging Face Hub, where it is
    cached locally, and how it should be loaded.
    """
    model_id: str  # Hugging Face Hub repository id (e.g. "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    model_path: str  # local directory the model is downloaded/cached into
    model_type: ModelType  # capability class of the model
    device: DeviceType = "auto"  # execution device; "auto" presumably lets the loader pick -- TODO confirm
    quantize: bool = True  # whether a quantized variant should be loaded -- presumably consumed by the loader
    use_safetensors: bool = True  # prefer .safetensors weights over pickle-based formats
    trust_remote_code: bool = True  # NOTE(review): allows executing repo-provided code on load; security-sensitive default
    description: str = ""  # human-readable summary of the model
    size_gb: float = 0.0  # Approximate size in GB
    recommended: bool = False  # marks the model as a sensible default choice
# Available models with their configurations.
# Keys are the short names accepted by get_model_config() / get_model_path().
DEFAULT_MODELS = {
    # Lightweight models (under 2GB)
    "tiny-llama": ModelConfig(
        model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        model_path="./models/tiny-llama-1.1b-chat",
        model_type="text-generation",
        quantize=True,
        description="Very small and fast model, good for quick testing",
        size_gb=1.1,
        recommended=True
    ),
    # Medium models (2-10GB)
    "mistral-7b": ModelConfig(
        model_id="TheBloke/Mistral-7B-Instruct-v0.1-GPTQ",
        model_path="./models/mistral-7b-instruct-gptq",
        model_type="text-generation",
        quantize=True,
        description="Good balance of performance and resource usage",
        size_gb=4.0,
        recommended=True
    ),
    # NOTE(review): meta-llama repos are gated on the Hub -- presumably
    # downloading requires an authenticated HF token; verify before relying on it.
    "llama2-7b": ModelConfig(
        model_id="meta-llama/Llama-2-7b-chat-hf",
        model_path="./models/llama2-7b-chat",
        model_type="text-generation",
        description="High quality 7B parameter model from Meta",
        size_gb=13.0
    ),
    # Embedding models
    "all-mpnet-base-v2": ModelConfig(
        model_id="sentence-transformers/all-mpnet-base-v2",
        model_path="./models/all-mpnet-base-v2",
        model_type="text-embedding",
        description="General purpose sentence transformer, good balance of speed and quality",
        size_gb=0.4,
        recommended=True
    ),
    "bge-small-en": ModelConfig(
        model_id="BAAI/bge-small-en-v1.5",
        model_path="./models/bge-small-en",
        model_type="text-embedding",
        description="Small but powerful embedding model",
        size_gb=0.13
    ),
    # Larger models (10GB+)
    "mixtral-8x7b": ModelConfig(
        model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        model_path="./models/mixtral-8x7b-instruct",
        model_type="text-generation",
        description="Very powerful 8x7B MoE model, requires significant resources",
        size_gb=85.0
    ),
    "llama2-13b": ModelConfig(
        model_id="meta-llama/Llama-2-13b-chat-hf",
        model_path="./models/llama2-13b-chat",
        model_type="text-generation",
        description="High quality 13B parameter model, better reasoning capabilities",
        size_gb=26.0
    )
}
def get_model_config(model_name: str) -> Optional[ModelConfig]:
    """Look up the configuration registered under *model_name*.

    Returns None when no model with that name exists in DEFAULT_MODELS.
    """
    try:
        return DEFAULT_MODELS[model_name]
    except KeyError:
        return None
def list_available_models() -> List[str]:
    """Return the short names of every model registered in DEFAULT_MODELS."""
    return [name for name in DEFAULT_MODELS]
def get_model_path(model_name: str) -> str:
    """Return the local directory for *model_name*, downloading it on first use.

    Args:
        model_name: A key of DEFAULT_MODELS (see list_available_models()).

    Returns:
        The local directory path configured for the model.

    Raises:
        ValueError: If *model_name* is not a registered model.
    """
    config = get_model_config(model_name)
    if config is None:
        available = ", ".join(sorted(DEFAULT_MODELS))
        raise ValueError(f"Unknown model: {model_name}. Available models: {available}")
    # Ensure the target directory exists before checking/downloading.
    os.makedirs(config.model_path, exist_ok=True)
    # The presence of config.json is used as a cheap "already downloaded"
    # marker -- presumably every referenced Hub snapshot ships one; verify
    # for any newly added model entries.
    if not os.path.exists(os.path.join(config.model_path, "config.json")):
        # Imported lazily: huggingface_hub is heavyweight and only needed
        # on the first download of a given model.
        from huggingface_hub import snapshot_download

        # NOTE: the deprecated `local_dir_use_symlinks` argument was dropped;
        # huggingface_hub >= 0.23 ignores it and warns when it is passed.
        snapshot_download(
            repo_id=config.model_id,
            local_dir=config.model_path,
            ignore_patterns=["*.h5", "*.ot", "*.msgpack"],  # skip non-PyTorch weight formats
        )
    return config.model_path