import hashlib
import logging
import pickle
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
from sentence_transformers import SentenceTransformer

from app.core.config import settings

logger = logging.getLogger(__name__)


class EmbeddingService:
    """Generates, caches, and compares sentence embeddings."""

    def __init__(self, model_name: Optional[str] = None):
        self.model_name = model_name or settings.EMBEDDING_MODEL
        self.model = None
        self.cache_dir = Path("../data/embeddings_cache")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._load_model()
        # Fall back to 384 (the dimension of all-MiniLM-L6-v2) if the model
        # does not report its embedding dimension.
        self.embedding_dim = self.model.get_sentence_embedding_dimension() or 384

    def _load_model(self):
        """Load the SentenceTransformer model"""
        try:
            logger.info(f"Loading embedding model: {self.model_name}")
            self.model = SentenceTransformer(self.model_name)
            logger.info("Embedding model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}")
            raise RuntimeError(f"Could not initialize embedding model: {e}") from e

    def _get_cache_key(self, text: str) -> str:
        """Generate cache key for text"""
        return hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest()

    def _get_cached_embedding(self, cache_key: str) -> Optional[np.ndarray]:
        """Get embedding from cache if available"""
        cache_file = self.cache_dir / f"{cache_key}.pkl"
        if cache_file.exists():
            try:
                with open(cache_file, 'rb') as f:
                    return pickle.load(f)
            except Exception as e:
                logger.warning(f"Failed to load cached embedding: {e}")
        return None

    def _cache_embedding(self, cache_key: str, embedding: np.ndarray):
        """Cache embedding for future use"""
        cache_file = self.cache_dir / f"{cache_key}.pkl"
        try:
            with open(cache_file, 'wb') as f:
                pickle.dump(embedding, f)
        except Exception as e:
            logger.warning(f"Failed to cache embedding: {e}")

    def encode_text(self, text: str, use_cache: bool = True) -> np.ndarray:
        """Generate embedding for a single text"""
        if not text or not text.strip():
            return np.zeros(self.embedding_dim)

        cache_key = self._get_cache_key(text)

        # Check cache first
        if use_cache:
            cached_embedding = self._get_cached_embedding(cache_key)
            if cached_embedding is not None:
                return cached_embedding

        try:
            # Generate embedding
            embedding = self.model.encode(text, convert_to_numpy=True)

            # Cache for future use
            if use_cache:
                self._cache_embedding(cache_key, embedding)

            return embedding

        except Exception as e:
            logger.error(f"Failed to generate embedding: {e}")
            return np.zeros(self.embedding_dim)

    def encode_texts(self, texts: List[str], use_cache: bool = True, batch_size: int = 32) -> List[np.ndarray]:
        """Generate embeddings for multiple texts"""
        if not texts:
            return []

        embeddings: List[Optional[np.ndarray]] = []
        texts_to_encode = []
        cache_keys: Dict[int, str] = {}  # original index -> cache key
        indices_to_encode = []

        # Check cache for each text
        for i, text in enumerate(texts):
            if not text or not text.strip():
                embeddings.append(np.zeros(self.embedding_dim))
                continue

            cache_key = self._get_cache_key(text)
            cache_keys[i] = cache_key

            if use_cache:
                cached_embedding = self._get_cached_embedding(cache_key)
                if cached_embedding is not None:
                    embeddings.append(cached_embedding)
                    continue

            # Need to encode this text
            texts_to_encode.append(text)
            indices_to_encode.append(i)
            embeddings.append(None)  # Placeholder

        # Encode texts that weren't cached
        if texts_to_encode:
            try:
                new_embeddings = self.model.encode(
                    texts_to_encode,
                    convert_to_numpy=True,
                    batch_size=batch_size
                )

                # Cache and place new embeddings, keyed by original position
                for idx, embedding in zip(indices_to_encode, new_embeddings):
                    embeddings[idx] = embedding
                    if use_cache:
                        self._cache_embedding(cache_keys[idx], embedding)

            except Exception as e:
                logger.error(f"Failed to generate batch embeddings: {e}")
                # Fill with zeros for failed embeddings
                for idx in indices_to_encode:
                    embeddings[idx] = np.zeros(self.embedding_dim)

        return embeddings

    def compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """Compute cosine similarity between two embeddings"""
        try:
            # Normalize embeddings
            norm1 = np.linalg.norm(embedding1)
            norm2 = np.linalg.norm(embedding2)

            if norm1 == 0 or norm2 == 0:
                return 0.0

            # Cosine similarity
            similarity = np.dot(embedding1, embedding2) / (norm1 * norm2)
            return float(similarity)

        except Exception as e:
            logger.error(f"Failed to compute similarity: {e}")
            return 0.0

    def find_most_similar(
        self,
        query_embedding: np.ndarray,
        candidate_embeddings: List[np.ndarray],
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """Find most similar embeddings to query"""
        similarities = []

        for i, candidate in enumerate(candidate_embeddings):
            similarity = self.compute_similarity(query_embedding, candidate)
            similarities.append({
                'index': i,
                'similarity': similarity
            })

        # Sort by similarity (descending)
        similarities.sort(key=lambda x: x['similarity'], reverse=True)

        return similarities[:top_k]

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model"""
        if not self.model:
            return {}

        return {
            'model_name': self.model_name,
            'max_sequence_length': getattr(self.model, 'max_seq_length', 'unknown'),
            'embedding_dimension': self.model.get_sentence_embedding_dimension(),
        }

    def clear_cache(self):
        """Clear the embedding cache"""
        try:
            for cache_file in self.cache_dir.glob("*.pkl"):
                cache_file.unlink()
            logger.info("Embedding cache cleared")
        except Exception as e:
            logger.error(f"Failed to clear cache: {e}")


# Global instance
embedding_service = EmbeddingService()
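

# A minimal usage sketch (illustrative only, not part of the service). It
# assumes `settings.EMBEDDING_MODEL` names a valid SentenceTransformer model
# such as all-MiniLM-L6-v2 and that the app package is on the import path;
# the example texts and query below are hypothetical.
if __name__ == "__main__":
    docs = [
        "Embeddings map text to dense vectors.",
        "Cosine similarity compares two vectors.",
        "A completely unrelated sentence about cooking pasta.",
    ]

    # Batch-encode the documents, encode a query, then rank the documents
    # against the query using the cosine-similarity helpers defined above.
    doc_embeddings = embedding_service.encode_texts(docs)
    query_embedding = embedding_service.encode_text("How do I compare two embeddings?")

    for match in embedding_service.find_most_similar(query_embedding, doc_embeddings, top_k=2):
        print(f"{match['similarity']:.3f}  {docs[match['index']]}")

    print(embedding_service.get_model_info())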