from sentence_transformers import SentenceTransformer
import numpy as np
from typing import List, Dict, Any, Optional
import hashlib
import pickle
from pathlib import Path

from app.core.config import settings
import logging

logger = logging.getLogger(__name__)


class EmbeddingService:
    def __init__(self, model_name: Optional[str] = None):
        self.model_name = model_name or settings.EMBEDDING_MODEL
        self.model = None
        self.cache_dir = Path("../data/embeddings_cache")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._load_model()

    def _load_model(self):
        """Load the SentenceTransformer model"""
        try:
            logger.info(f"Loading embedding model: {self.model_name}")
            self.model = SentenceTransformer(self.model_name)
            # Fall back to 384 (the all-MiniLM-L6-v2 size) if the model does not report a dimension
            self.embedding_dim = self.model.get_sentence_embedding_dimension() or 384
            logger.info("Embedding model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}")
            raise RuntimeError(f"Could not initialize embedding model: {e}") from e

    def _get_cache_key(self, text: str) -> str:
        """Generate cache key for text"""
        return hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest()

    def _get_cached_embedding(self, cache_key: str) -> Optional[np.ndarray]:
        """Get embedding from cache if available"""
        cache_file = self.cache_dir / f"{cache_key}.pkl"
        if cache_file.exists():
            try:
                with open(cache_file, 'rb') as f:
                    return pickle.load(f)
            except Exception as e:
                logger.warning(f"Failed to load cached embedding: {e}")
        return None

    def _cache_embedding(self, cache_key: str, embedding: np.ndarray):
        """Cache embedding for future use"""
        cache_file = self.cache_dir / f"{cache_key}.pkl"
        try:
            with open(cache_file, 'wb') as f:
                pickle.dump(embedding, f)
        except Exception as e:
            logger.warning(f"Failed to cache embedding: {e}")

    def encode_text(self, text: str, use_cache: bool = True) -> np.ndarray:
        """Generate embedding for a single text"""
        if not text or not text.strip():
            # Zero vector of the model's embedding size (384 for all-MiniLM-L6-v2)
            return np.zeros(self.embedding_dim)

        cache_key = self._get_cache_key(text)

        # Check cache first
        if use_cache:
            cached_embedding = self._get_cached_embedding(cache_key)
            if cached_embedding is not None:
                return cached_embedding

        try:
            # Generate embedding
            embedding = self.model.encode(text, convert_to_numpy=True)

            # Cache for future use
            if use_cache:
                self._cache_embedding(cache_key, embedding)

            return embedding
        except Exception as e:
            logger.error(f"Failed to generate embedding: {e}")
            return np.zeros(self.embedding_dim)

    def encode_texts(
        self,
        texts: List[str],
        use_cache: bool = True,
        batch_size: int = 32
    ) -> List[np.ndarray]:
        """Generate embeddings for multiple texts"""
        if not texts:
            return []

        embeddings: List[Optional[np.ndarray]] = []
        texts_to_encode = []
        keys_to_encode = []
        indices_to_encode = []

        # Check cache for each text
        for i, text in enumerate(texts):
            if not text or not text.strip():
                embeddings.append(np.zeros(self.embedding_dim))
                continue

            cache_key = self._get_cache_key(text)

            if use_cache:
                cached_embedding = self._get_cached_embedding(cache_key)
                if cached_embedding is not None:
                    embeddings.append(cached_embedding)
                    continue

            # Need to encode this text; keep its cache key aligned with its position
            texts_to_encode.append(text)
            keys_to_encode.append(cache_key)
            indices_to_encode.append(i)
            embeddings.append(None)  # Placeholder

        # Encode texts that weren't cached
        if texts_to_encode:
            try:
                new_embeddings = self.model.encode(
                    texts_to_encode,
                    convert_to_numpy=True,
                    batch_size=batch_size
                )

                # Cache and place new embeddings
                for idx, cache_key, embedding in zip(indices_to_encode, keys_to_encode, new_embeddings):
                    embeddings[idx] = embedding
                    if use_cache:
                        self._cache_embedding(cache_key, embedding)
            except Exception as e:
                logger.error(f"Failed to generate batch embeddings: {e}")
                # Fill with zeros for failed embeddings
                for idx in indices_to_encode:
                    embeddings[idx] = np.zeros(self.embedding_dim)

        return embeddings

    def compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """Compute cosine similarity between two embeddings"""
        try:
            # Normalize embeddings
            norm1 = np.linalg.norm(embedding1)
            norm2 = np.linalg.norm(embedding2)

            if norm1 == 0 or norm2 == 0:
                return 0.0

            # Cosine similarity
            similarity = np.dot(embedding1, embedding2) / (norm1 * norm2)
            return float(similarity)
        except Exception as e:
            logger.error(f"Failed to compute similarity: {e}")
            return 0.0

    def find_most_similar(
        self,
        query_embedding: np.ndarray,
        candidate_embeddings: List[np.ndarray],
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """Find most similar embeddings to query"""
        similarities = []

        for i, candidate in enumerate(candidate_embeddings):
            similarity = self.compute_similarity(query_embedding, candidate)
            similarities.append({
                'index': i,
                'similarity': similarity
            })

        # Sort by similarity (descending)
        similarities.sort(key=lambda x: x['similarity'], reverse=True)

        return similarities[:top_k]

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model"""
        if not self.model:
            return {}

        return {
            'model_name': self.model_name,
            'max_sequence_length': getattr(self.model, 'max_seq_length', 'unknown'),
            'embedding_dimension': self.model.get_sentence_embedding_dimension(),
        }

    def clear_cache(self):
        """Clear the embedding cache"""
        try:
            for cache_file in self.cache_dir.glob("*.pkl"):
                cache_file.unlink()
            logger.info("Embedding cache cleared")
        except Exception as e:
            logger.error(f"Failed to clear cache: {e}")


# Global instance
embedding_service = EmbeddingService()
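

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the service API).
# Assumes settings.EMBEDDING_MODEL resolves to a valid SentenceTransformer
# model name such as "all-MiniLM-L6-v2"; the documents and query below are
# hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    docs = [
        "Embeddings map text to dense vectors.",
        "Cosine similarity compares two vectors.",
        "An unrelated sentence about cooking pasta.",
    ]

    # Batch-encode the documents (cached on disk after the first run)
    doc_embeddings = embedding_service.encode_texts(docs)

    # Encode a query and rank documents by cosine similarity
    query_embedding = embedding_service.encode_text("How do I compare two vectors?")
    top_matches = embedding_service.find_most_similar(query_embedding, doc_embeddings, top_k=2)

    for match in top_matches:
        print(f"{match['similarity']:.3f}  {docs[match['index']]}")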