# NOTE(review): the three lines below were non-code residue scraped from the
# hosting page, converted to a comment so the module parses:
#   "asadshahab's picture" / "initial" / commit dd850a7
from typing import Dict, Any, Optional
import hashlib
import json
import torch
import pickle
import io
class AttentionCache:
    """In-memory LRU cache for attention/generation results.

    Values are pickled on insertion (torch tensors are moved to CPU numpy
    arrays first) and unpickled on retrieval, so callers always receive an
    independent copy of the cached data. Not thread-safe.
    """

    def __init__(self, max_size: int = 10):
        # key -> pickled bytes payload
        self.cache: Dict[str, bytes] = {}
        # Keys ordered least- to most-recently used (front = next eviction).
        self.access_order = []
        self.max_size = max_size

    def get_key(self, prompt: str, max_tokens: int, model: str, temperature: float = 0.7) -> str:
        """Generate a deterministic cache key from the request parameters.

        MD5 is used purely as a fast fingerprint here, not for security.
        """
        data = f"{prompt}_{max_tokens}_{model}_{temperature}"
        return hashlib.md5(data.encode()).hexdigest()

    def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the cached value for ``key``, or None on a miss.

        A hit marks the key as most recently used.
        """
        if key in self.cache:
            # Move to the most-recently-used end (LRU bookkeeping).
            self.access_order.remove(key)
            self.access_order.append(key)
            return self._deserialize(self.cache[key])
        return None

    def set(self, key: str, data: Dict[str, Any]):
        """Store ``data`` under ``key``, evicting the LRU entry when full.

        Fix: overwriting an existing key now refreshes its recency instead of
        appending a duplicate to ``access_order``. The duplicate made later
        evictions delete the wrong entry and could eventually raise KeyError
        on ``del self.cache[oldest]`` for a stale key.
        """
        if key in self.cache:
            # Existing key: refresh recency; no eviction needed for overwrite.
            self.access_order.remove(key)
        elif len(self.cache) >= self.max_size:
            # Cache full: drop the least recently used entry.
            oldest = self.access_order.pop(0)
            del self.cache[oldest]
        self.cache[key] = self._serialize(data)
        self.access_order.append(key)

    def _serialize(self, data: Dict[str, Any]) -> bytes:
        """Pickle ``data``, first converting torch tensors found inside
        lists of dicts (attention matrices) to CPU numpy arrays.

        ``detach()`` is added so tensors with ``requires_grad=True`` can be
        serialized without raising RuntimeError.
        """
        serialized: Dict[str, Any] = {}
        for key, value in data.items():
            # Trigger conversion only for non-empty lists of dicts whose
            # first element contains at least one tensor (same heuristic as
            # the original implementation).
            if (
                isinstance(value, list)
                and value
                and isinstance(value[0], dict)
                and any(isinstance(v, torch.Tensor) for v in value[0].values())
            ):
                serialized[key] = [
                    {
                        k: v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v
                        for k, v in item.items()
                    }
                    for item in value
                ]
            else:
                serialized[key] = value
        buffer = io.BytesIO()
        pickle.dump(serialized, buffer)
        return buffer.getvalue()

    def _deserialize(self, data: bytes) -> Dict[str, Any]:
        """Unpickle cached bytes, restoring numpy arrays inside lists of
        dicts back to torch tensors.

        Only safe for bytes produced by :meth:`_serialize` — never feed
        untrusted data to pickle.
        """
        import numpy as np  # local import mirrors original usage

        deserialized = pickle.loads(data)
        for value in deserialized.values():
            if isinstance(value, list) and value and isinstance(value[0], dict):
                for item in value:
                    for k, v in item.items():
                        # numpy arrays here were tensors before caching.
                        if isinstance(v, np.ndarray):
                            item[k] = torch.from_numpy(v)
        return deserialized

    def clear(self):
        """Remove every cached entry and reset LRU bookkeeping."""
        self.cache.clear()
        self.access_order.clear()

    def size(self) -> int:
        """Return the number of entries currently cached."""
        return len(self.cache)