|
from typing import Any, Dict, Optional

import hashlib
import io
import pickle

import numpy as np
import torch

|
class AttentionCache:
    """Small LRU cache for generation results whose payloads may contain torch tensors."""

    def __init__(self, max_size: int = 10):
        self.cache: Dict[str, bytes] = {}
        # Keys ordered from least to most recently used.
        self.access_order = []
        self.max_size = max_size
|
    def get_key(self, prompt: str, max_tokens: int, model: str, temperature: float = 0.7) -> str:
        """Generate a cache key from the request parameters."""
        data = f"{prompt}_{max_tokens}_{model}_{temperature}"
        # MD5 is used as a fast fingerprint here, not for security.
        return hashlib.md5(data.encode()).hexdigest()
|
    def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Retrieve cached data, or None on a cache miss."""
        if key in self.cache:
            # Mark the key as most recently used.
            self.access_order.remove(key)
            self.access_order.append(key)
            return self._deserialize(self.cache[key])
        return None
|
    def set(self, key: str, data: Dict[str, Any]):
        """Store data in the cache, evicting the least recently used entry when full."""
        if key in self.cache:
            # Re-setting an existing key must not leave a duplicate in access_order.
            self.access_order.remove(key)
        elif len(self.cache) >= self.max_size:
            oldest = self.access_order.pop(0)
            del self.cache[oldest]

        self.cache[key] = self._serialize(data)
        self.access_order.append(key)
|
    def _serialize(self, data: Dict[str, Any]) -> bytes:
        """Serialize data for caching, converting torch tensors to numpy arrays."""
        serialized = {}
        for key, value in data.items():
            if (
                isinstance(value, list)
                and len(value) > 0
                and isinstance(value[0], dict)
                and any(isinstance(v, torch.Tensor) for v in value[0].values())
            ):
                # A list of dicts holding tensors (e.g. per-layer attention
                # outputs): move each tensor to the CPU and store it as a
                # numpy array so the payload pickles cleanly.
                serialized_list = []
                for item in value:
                    serialized_item = {}
                    for k, v in item.items():
                        if isinstance(v, torch.Tensor):
                            serialized_item[k] = v.cpu().numpy()
                        else:
                            serialized_item[k] = v
                    serialized_list.append(serialized_item)
                serialized[key] = serialized_list
            else:
                serialized[key] = value

        buffer = io.BytesIO()
        pickle.dump(serialized, buffer)
        return buffer.getvalue()
|
    def _deserialize(self, data: bytes) -> Dict[str, Any]:
        """Deserialize cached data, restoring torch tensors from numpy arrays."""
        buffer = io.BytesIO(data)
        deserialized = pickle.load(buffer)

        for key, value in deserialized.items():
            if isinstance(value, list) and len(value) > 0 and isinstance(value[0], dict):
                # Convert numpy arrays back into torch tensors in place.
                for item in value:
                    for k, v in item.items():
                        if isinstance(v, np.ndarray):
                            item[k] = torch.from_numpy(v)

        return deserialized
|
    def clear(self):
        """Clear the entire cache."""
        self.cache.clear()
        self.access_order.clear()
|
    def size(self) -> int:
        """Get current cache size."""
        return len(self.cache)
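

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the cache API): the
    # payload layout below, a "layers" list of dicts carrying an "attn"
    # tensor, is an assumed example of the tensor-bearing structures that
    # _serialize handles, not a required schema.
    cache = AttentionCache(max_size=2)

    key = cache.get_key("hello world", max_tokens=16, model="demo-model")
    cache.set(key, {"layers": [{"attn": torch.ones(2, 2)}]})

    # Tensors survive the serialize/deserialize round trip.
    restored = cache.get(key)
    assert restored is not None
    assert torch.equal(restored["layers"][0]["attn"], torch.ones(2, 2))

    # Writing two more keys to a size-2 cache evicts the least recently used.
    cache.set("k2", {"layers": []})
    cache.set("k3", {"layers": []})
    assert cache.get(key) is None
    print("entries cached:", cache.size())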