# apersona/backend/ai_core/embeddings/embedding_service.py
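"""Embedding service: wraps a SentenceTransformer model with an on-disk pickle
cache and cosine-similarity helpers for ranking candidate embeddings."""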


from sentence_transformers import SentenceTransformer
import numpy as np
from typing import List, Dict, Any, Optional
import hashlib
import pickle
from pathlib import Path
from app.core.config import settings
import logging

logger = logging.getLogger(__name__)

class EmbeddingService:
    def __init__(self, model_name: Optional[str] = None):
        self.model_name = model_name or settings.EMBEDDING_MODEL
        self.model = None
        self.embedding_dim = 384  # Fallback for all-MiniLM-L6-v2; updated once the model is loaded
        # Note: the cache directory is resolved relative to the current working directory
        self.cache_dir = Path("../data/embeddings_cache")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._load_model()

    def _load_model(self):
        """Load the SentenceTransformer model"""
        try:
            logger.info(f"Loading embedding model: {self.model_name}")
            self.model = SentenceTransformer(self.model_name)
            self.embedding_dim = self.model.get_sentence_embedding_dimension() or self.embedding_dim
            logger.info("Embedding model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}")
            raise RuntimeError(f"Could not initialize embedding model: {e}") from e

    def _get_cache_key(self, text: str) -> str:
        """Generate a cache key for a text under the current model"""
        return hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest()

    def _get_cached_embedding(self, cache_key: str) -> Optional[np.ndarray]:
        """Return a cached embedding if available, otherwise None"""
        cache_file = self.cache_dir / f"{cache_key}.pkl"
        if cache_file.exists():
            try:
                with open(cache_file, 'rb') as f:
                    return pickle.load(f)
            except Exception as e:
                logger.warning(f"Failed to load cached embedding: {e}")
        return None

    def _cache_embedding(self, cache_key: str, embedding: np.ndarray):
        """Cache an embedding for future use"""
        cache_file = self.cache_dir / f"{cache_key}.pkl"
        try:
            with open(cache_file, 'wb') as f:
                pickle.dump(embedding, f)
        except Exception as e:
            logger.warning(f"Failed to cache embedding: {e}")

    def encode_text(self, text: str, use_cache: bool = True) -> np.ndarray:
        """Generate an embedding for a single text"""
        if not text or not text.strip():
            return np.zeros(self.embedding_dim)

        cache_key = self._get_cache_key(text)

        # Check cache first
        if use_cache:
            cached_embedding = self._get_cached_embedding(cache_key)
            if cached_embedding is not None:
                return cached_embedding

        try:
            # Generate embedding
            embedding = self.model.encode(text, convert_to_numpy=True)

            # Cache for future use
            if use_cache:
                self._cache_embedding(cache_key, embedding)

            return embedding
        except Exception as e:
            logger.error(f"Failed to generate embedding: {e}")
            return np.zeros(self.embedding_dim)

    def encode_texts(self, texts: List[str], use_cache: bool = True, batch_size: int = 32) -> List[np.ndarray]:
        """Generate embeddings for multiple texts"""
        if not texts:
            return []

        embeddings = []
        texts_to_encode = []
        keys_to_encode = []      # cache keys aligned with texts_to_encode
        indices_to_encode = []   # output positions aligned with texts_to_encode

        # Check cache for each text
        for i, text in enumerate(texts):
            if not text or not text.strip():
                embeddings.append(np.zeros(self.embedding_dim))
                continue

            cache_key = self._get_cache_key(text)

            if use_cache:
                cached_embedding = self._get_cached_embedding(cache_key)
                if cached_embedding is not None:
                    embeddings.append(cached_embedding)
                    continue

            # Need to encode this text
            texts_to_encode.append(text)
            keys_to_encode.append(cache_key)
            indices_to_encode.append(i)
            embeddings.append(None)  # Placeholder, filled in below

        # Encode texts that weren't cached
        if texts_to_encode:
            try:
                new_embeddings = self.model.encode(
                    texts_to_encode,
                    convert_to_numpy=True,
                    batch_size=batch_size
                )

                # Cache and place new embeddings at their original positions
                for idx, cache_key, embedding in zip(indices_to_encode, keys_to_encode, new_embeddings):
                    embeddings[idx] = embedding
                    if use_cache:
                        self._cache_embedding(cache_key, embedding)
            except Exception as e:
                logger.error(f"Failed to generate batch embeddings: {e}")
                # Fill with zeros for failed embeddings
                for idx in indices_to_encode:
                    embeddings[idx] = np.zeros(self.embedding_dim)

        return embeddings

    def compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """Compute cosine similarity between two embeddings"""
        try:
            # Zero vectors have no direction, so treat similarity as 0
            norm1 = np.linalg.norm(embedding1)
            norm2 = np.linalg.norm(embedding2)
            if norm1 == 0 or norm2 == 0:
                return 0.0

            # Cosine similarity: dot product divided by the product of the norms
            similarity = np.dot(embedding1, embedding2) / (norm1 * norm2)
            return float(similarity)
        except Exception as e:
            logger.error(f"Failed to compute similarity: {e}")
            return 0.0

    def find_most_similar(
        self,
        query_embedding: np.ndarray,
        candidate_embeddings: List[np.ndarray],
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """Find most similar embeddings to query"""
        similarities = []
        for i, candidate in enumerate(candidate_embeddings):
            similarity = self.compute_similarity(query_embedding, candidate)
            similarities.append({
                'index': i,
                'similarity': similarity
            })

        # Sort by similarity (descending)
        similarities.sort(key=lambda x: x['similarity'], reverse=True)
        return similarities[:top_k]

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model"""
        if not self.model:
            return {}

        return {
            'model_name': self.model_name,
            'max_sequence_length': getattr(self.model, 'max_seq_length', 'unknown'),
            'embedding_dimension': self.model.get_sentence_embedding_dimension(),
        }

    def clear_cache(self):
        """Clear the embedding cache"""
        try:
            for cache_file in self.cache_dir.glob("*.pkl"):
                cache_file.unlink()
            logger.info("Embedding cache cleared")
        except Exception as e:
            logger.error(f"Failed to clear cache: {e}")


# Global instance
embedding_service = EmbeddingService()
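

# Minimal usage sketch (illustrative only): shows how encode_texts, encode_text,
# and find_most_similar are expected to be combined. The example strings below are
# hypothetical, and running this assumes the configured model loads successfully.
if __name__ == "__main__":
    docs = ["How do I reset my password?", "Billing questions", "Password recovery steps"]
    doc_embeddings = embedding_service.encode_texts(docs)

    query_embedding = embedding_service.encode_text("forgot my password")
    matches = embedding_service.find_most_similar(query_embedding, doc_embeddings, top_k=2)

    for match in matches:
        print(f"{docs[match['index']]!r} -> similarity {match['similarity']:.3f}")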