MouadHsb committed
Commit cee458a · 1 Parent(s): c1911d8

Switching to API for embedding

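The commit replaces local, in-process model inference with calls to the Hugging Face Inference API. For context, the new service is built around the huggingface_hub call sketched below; this is a minimal illustration, not part of the commit, and the token value is a placeholder.

# Minimal sketch of the Inference API call the new service relies on.
# Requires huggingface_hub; the api_key value is a placeholder.
from huggingface_hub import InferenceClient

client = InferenceClient(api_key="hf_...")  # placeholder token
vec = client.feature_extraction(
    text="The Eiffel Tower is in Paris.",
    model="sentence-transformers/all-MiniLM-L6-v2",
)
print(vec.shape)  # expect (384,) for this model; shape can vary by model
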
app/services/embedding_service copy.py CHANGED
@@ -1,47 +1,61 @@
 import logging
 import numpy as np
-from typing import List, Dict, Any, Tuple
-from sentence_transformers import SentenceTransformer
 import torch
-import os
-
-os.environ["PYTORCH_ENABLE_META_TENSORS"] = "0"
+from typing import List, Dict, Any, Tuple
+from transformers import AutoModel, AutoTokenizer
 
 logger = logging.getLogger(__name__)
 
 class EmbeddingService:
-    """Service for handling document embeddings using Sentence Transformers."""
+    """Service for handling document embeddings using Hugging Face Transformers."""
 
-    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
+    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
         """
         Initialize the embedding system.
 
         Args:
-            model_name: Name of the Sentence Transformers model to use. Default is "all-MiniLM-L6-v2" (80MB).
+            model_name: Name of the model to use. Default is "sentence-transformers/all-MiniLM-L6-v2".
         """
         logger.info(f"Loading embedding model: {model_name}")
 
-        # Explicitly set device to CPU to avoid meta tensor issue
+        # Disable gradients for inference
        torch.set_grad_enabled(False)
-        # With 16GB of RAM, we can afford to use standard loading without memory optimization
-        # Force the model to load fully into memory without any meta tensors
-        torch.cuda.empty_cache() if torch.cuda.is_available() else None
-
-        #########################################################
-        torch.set_default_device("cpu")
-        #########################################################
 
-        self.model = SentenceTransformer(model_name, device="cpu")
-
-        # Ensure model is fully materialized, not using meta tensors
-        for param in self.model.parameters():
-            if hasattr(param, 'is_meta') and param.is_meta:
-                # Should not happen with environment variable set, but just in case
-                raise RuntimeError("Meta tensors still detected despite disabling them")
+        # Load tokenizer and model
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModel.from_pretrained(model_name)
+
+        self.model.eval()
 
-        self.embedding_dim = self.model.get_sentence_embedding_dimension()
+        # Get embedding dimension from model config
+        self.embedding_dim = self.model.config.hidden_size
         logger.info(f"Embedding dimension: {self.embedding_dim}")
+
+    def _mean_pooling(self, model_output, attention_mask):
+        """
+        Perform mean pooling on token embeddings.
+
+        Args:
+            model_output: Output from the transformer model
+            attention_mask: Attention mask to avoid padding tokens
+
+        Returns:
+            Sentence embeddings
+        """
+        # First element of model_output contains token embeddings
+        token_embeddings = model_output[0]
+
+        # Expand attention mask from [batch_size, seq_length] to [batch_size, seq_length, hidden_size]
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+
+        # Sum token embeddings and divide by the expanded mask
+        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
+        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+        # Return mean-pooled embeddings
+        return sum_embeddings / sum_mask
 
     def embed_documents(self, documents: List[Dict[str, Any]]) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
         """
         Embed a list of documents.
@@ -54,7 +68,38 @@ class EmbeddingService:
         """
         texts = [doc["text"] for doc in documents]
         logger.info(f"Embedding {len(texts)} documents...")
-        embeddings = self.model.encode(texts, show_progress_bar=True, convert_to_numpy=True)
+
+        # Process in batches to avoid OOM
+        batch_size = 8
+        all_embeddings = []
+
+        for i in range(0, len(texts), batch_size):
+            batch_texts = texts[i:i+batch_size]
+
+            # Tokenize batch
+            encoded_input = self.tokenizer(
+                batch_texts,
+                padding=True,
+                truncation=True,
+                max_length=512,
+                return_tensors='pt'
+            ).to("cpu")
+
+            # Compute token embeddings
+            with torch.no_grad():
+                model_output = self.model(**encoded_input)
+
+            # Apply mean pooling
+            batch_embeddings = self._mean_pooling(model_output, encoded_input['attention_mask'])
+
+            # Normalize embeddings
+            batch_embeddings = torch.nn.functional.normalize(batch_embeddings, p=2, dim=1)
+
+            # Convert to numpy and add to result
+            all_embeddings.append(batch_embeddings.cpu().numpy())
+
+        # Combine all batches
+        embeddings = np.vstack(all_embeddings)
 
         return embeddings, documents
 
@@ -68,7 +113,26 @@ class EmbeddingService:
         Returns:
             Query embedding array.
         """
-        return self.model.encode([query], convert_to_numpy=True)
+        # Tokenize query
+        encoded_input = self.tokenizer(
+            [query],
+            padding=True,
+            truncation=True,
+            max_length=512,
+            return_tensors='pt'
+        ).to("cpu")
+
+        # Compute token embeddings
+        with torch.no_grad():
+            model_output = self.model(**encoded_input)
+
+        # Apply mean pooling
+        embeddings = self._mean_pooling(model_output, encoded_input['attention_mask'])
+
+        # Normalize embeddings
+        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+
+        return embeddings.cpu().numpy()
 
     def get_model_info(self) -> Dict[str, Any]:
         """
@@ -77,21 +141,9 @@ class EmbeddingService:
         Returns:
             Dictionary with model information.
         """
-        # Access the model attributes in a safer way
-        try:
-            model_name = self.model._model_config.get('name',
-                self.model._model_config.get('model_name_or_path', 'unknown'))
-        except:
-            model_name = str(self.model)  # Fallback to string representation
-
-        try:
-            max_seq_length = self.model.get_max_seq_length()
-        except:
-            max_seq_length = 512  # Default value if method not available
-
         return {
-            "model_name": model_name,
+            "model_name": self.model.config.name_or_path,
             "dimension": self.embedding_dim,
-            "max_seq_length": max_seq_length,
-            "normalize_embeddings": getattr(self.model, "normalize_embeddings", True)
-        }
+            "max_seq_length": self.model.config.max_position_embeddings,
+            "normalize_embeddings": True  # We're always normalizing
+        }
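
Note on this file: the _mean_pooling + L2-normalization recipe introduced above is the standard pooling for all-MiniLM-L6-v2, so it should reproduce what SentenceTransformer.encode returned before the change. A minimal sketch to check the equivalence (assumes both libraries are installed locally; not part of the commit):

# Sketch: verify that manual mean pooling + L2 normalization matches
# sentence-transformers output for the same checkpoint.
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer

name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name).eval()

texts = ["The Eiffel Tower is in Paris."]
enc = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
with torch.no_grad():
    out = model(**enc)

# Mean pooling over non-padding tokens, as in _mean_pooling above
mask = enc["attention_mask"].unsqueeze(-1).float()
pooled = (out[0] * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
manual = F.normalize(pooled, p=2, dim=1)

# This checkpoint's sentence-transformers pipeline applies the same pooling + normalization
ref = SentenceTransformer(name).encode(texts, convert_to_numpy=True)
print(torch.allclose(manual, torch.from_numpy(ref), atol=1e-5))  # expect True
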
app/services/embedding_service.py CHANGED
@@ -1,64 +1,48 @@
+import os
 import logging
 import numpy as np
-import torch
 from typing import List, Dict, Any, Tuple
-from transformers import AutoModel, AutoTokenizer
+from huggingface_hub import InferenceClient
 
 logger = logging.getLogger(__name__)
 
 class EmbeddingService:
-    """Service for handling document embeddings using Hugging Face Transformers."""
+    """Service for handling document embeddings using Hugging Face Inference API."""
 
-    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
+    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", api_key=None):
         """
         Initialize the embedding system.
 
         Args:
             model_name: Name of the model to use. Default is "sentence-transformers/all-MiniLM-L6-v2".
+            api_key: Hugging Face API key (will use env var if None)
         """
         logger.info(f"Loading embedding model: {model_name}")
 
-        # Disable gradients for inference
-        torch.set_grad_enabled(False)
-
-        # Load tokenizer and model
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModel.from_pretrained(model_name)
-
-
-        self.model.eval()
-
-        # Get embedding dimension from model config
-        self.embedding_dim = self.model.config.hidden_size
+        # Set up API credentials
+        self.api_key = api_key or os.environ.get("HF_API_KEY")
+        self.client = InferenceClient(api_key=self.api_key)
+
+        # Store model name for future references
+        self.model_name = model_name
+
+        # Known embedding dimensions for common models
+        # Update this if you use a different model
+        embedding_dims = {
+            "sentence-transformers/all-MiniLM-L6-v2": 384,
+            "sentence-transformers/all-mpnet-base-v2": 768,
+            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": 384,
+            "sentence-transformers/paraphrase-MiniLM-L6-v2": 384,
+            "BAAI/bge-small-en-v1.5": 384,
+            "BAAI/bge-base-en-v1.5": 768
+        }
+
+        self.embedding_dim = embedding_dims.get(model_name, 384)  # Default to 384 if unknown
         logger.info(f"Embedding dimension: {self.embedding_dim}")
 
-    def _mean_pooling(self, model_output, attention_mask):
-        """
-        Perform mean pooling on token embeddings.
-
-        Args:
-            model_output: Output from the transformer model
-            attention_mask: Attention mask to avoid padding tokens
-
-        Returns:
-            Sentence embeddings
-        """
-        # First element of model_output contains token embeddings
-        token_embeddings = model_output[0]
-
-        # Expand attention mask from [batch_size, seq_length] to [batch_size, seq_length, hidden_size]
-        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-
-        # Sum token embeddings and divide by the expanded mask
-        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
-        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-
-        # Return mean-pooled embeddings
-        return sum_embeddings / sum_mask
-
     def embed_documents(self, documents: List[Dict[str, Any]]) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
         """
-        Embed a list of documents.
+        Embed a list of documents using the Hugging Face Inference API.
 
         Args:
             documents: List of document dictionaries.
@@ -69,43 +53,42 @@ class EmbeddingService:
         texts = [doc["text"] for doc in documents]
         logger.info(f"Embedding {len(texts)} documents...")
 
-        # Process in batches to avoid OOM
-        batch_size = 8
+        # Process in reasonably sized batches to optimize API calls
+        batch_size = 32  # Adjust based on your needs and API limits
         all_embeddings = []
 
         for i in range(0, len(texts), batch_size):
             batch_texts = texts[i:i+batch_size]
 
-            # Tokenize batch
-            encoded_input = self.tokenizer(
-                batch_texts,
-                padding=True,
-                truncation=True,
-                max_length=512,
-                return_tensors='pt'
-            ).to("cpu")
-
-            # Compute token embeddings
-            with torch.no_grad():
-                model_output = self.model(**encoded_input)
-
-            # Apply mean pooling
-            batch_embeddings = self._mean_pooling(model_output, encoded_input['attention_mask'])
-
-            # Normalize embeddings
-            batch_embeddings = torch.nn.functional.normalize(batch_embeddings, p=2, dim=1)
-
-            # Convert to numpy and add to result
-            all_embeddings.append(batch_embeddings.cpu().numpy())
+            try:
+                # Call Inference API for feature-extraction (embeddings)
+                response = self.client.feature_extraction(
+                    text=batch_texts,
+                    model=self.model_name
+                )
+
+                # Convert response to numpy array and add to results
+                batch_embeddings = np.array(response)
+                all_embeddings.append(batch_embeddings)
+
+                logger.info(f"Successfully embedded batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")
+            except Exception as e:
+                logger.error(f"Error embedding batch {i//batch_size + 1}: {str(e)}")
+                # Skip problematic batch or raise exception
+                raise  # Re-raise for now to see errors in logs
 
         # Combine all batches
+        if not all_embeddings:
+            logger.warning("No embeddings were generated. Returning empty array.")
+            return np.array([]), documents
+
         embeddings = np.vstack(all_embeddings)
 
         return embeddings, documents
 
     def embed_query(self, query: str) -> np.ndarray:
         """
-        Embed a search query.
+        Embed a search query using the Hugging Face Inference API.
 
         Args:
             query: The search query.
@@ -113,26 +96,19 @@ class EmbeddingService:
         Returns:
             Query embedding array.
         """
-        # Tokenize query
-        encoded_input = self.tokenizer(
-            [query],
-            padding=True,
-            truncation=True,
-            max_length=512,
-            return_tensors='pt'
-        ).to("cpu")
-
-        # Compute token embeddings
-        with torch.no_grad():
-            model_output = self.model(**encoded_input)
-
-        # Apply mean pooling
-        embeddings = self._mean_pooling(model_output, encoded_input['attention_mask'])
-
-        # Normalize embeddings
-        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-
-        return embeddings.cpu().numpy()
+        try:
+            # Call Inference API for feature-extraction
+            response = self.client.feature_extraction(
+                text=[query],
+                model=self.model_name
+            )
+
+            # Convert to numpy array
+            embedding = np.array(response)
+            return embedding
+        except Exception as e:
+            logger.error(f"Error embedding query: {str(e)}")
+            raise  # Re-raise for now to see errors in logs
 
     def get_model_info(self) -> Dict[str, Any]:
         """
@@ -142,8 +118,8 @@ class EmbeddingService:
             Dictionary with model information.
         """
         return {
-            "model_name": self.model.config.name_or_path,
+            "model_name": self.model_name,
             "dimension": self.embedding_dim,
-            "max_seq_length": self.model.config.max_position_embeddings,
-            "normalize_embeddings": True # We're always normalizing
+            "max_seq_length": 512,  # Common default, may vary by model
+            "normalize_embeddings": True  # Typically normalized in sentence-transformers models
         }
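
Putting it together, the new API-backed service would be used roughly as follows. A hypothetical sketch: the import path mirrors the repo layout, the token is a placeholder, and document dicts carry a "text" key as embed_documents expects.

# Sketch: exercising the new API-backed EmbeddingService.
import os
os.environ["HF_API_KEY"] = "hf_..."  # placeholder token read by the service

from app.services.embedding_service import EmbeddingService

service = EmbeddingService()  # defaults to sentence-transformers/all-MiniLM-L6-v2
docs = [
    {"text": "Paris is the capital of France."},
    {"text": "The Eiffel Tower is in Paris."},
]

embeddings, docs = service.embed_documents(docs)  # one API call per batch of 32
query_vec = service.embed_query("Where is the Eiffel Tower?")
print(embeddings.shape, query_vec.shape)  # e.g. (2, 384) and (1, 384), per the API's response shape
print(service.get_model_info())
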