import json
import logging
from typing import AsyncGenerator, Dict, List, Optional

import httpx

from app.core.config import settings

logger = logging.getLogger(__name__)


class OllamaClient:
    """Async client for a local Ollama server: chat, streaming chat, and embeddings."""

    def __init__(self, base_url: Optional[str] = None, model: Optional[str] = None):
        self.base_url = base_url or settings.OLLAMA_BASE_URL
        self.model = model or settings.DEFAULT_LLM_MODEL
        self.client = httpx.AsyncClient(timeout=60.0)

    async def chat(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 2000
    ) -> str:
        """Send chat messages to Ollama and return the complete response text."""
        try:
            # Prepend the system prompt without mutating the caller's message list
            if system_prompt:
                messages = [{"role": "system", "content": system_prompt}] + messages

            payload = {
                "model": self.model,
                "messages": messages,
                "options": {
                    "temperature": temperature,
                    "num_predict": max_tokens
                },
                "stream": False
            }

            response = await self.client.post(
                f"{self.base_url}/api/chat",
                json=payload
            )
            response.raise_for_status()

            result = response.json()
            return result.get("message", {}).get("content", "")

        except httpx.RequestError as e:
            logger.error(f"Request error communicating with Ollama: {e}")
            raise Exception(f"Failed to communicate with local LLM: {e}")
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error from Ollama: {e}")
            raise Exception(f"LLM service error: {e}")

    async def chat_stream(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
        temperature: float = 0.7
    ) -> AsyncGenerator[str, None]:
        """Stream the chat response from Ollama as incremental text chunks."""
        try:
            # Prepend the system prompt without mutating the caller's message list
            if system_prompt:
                messages = [{"role": "system", "content": system_prompt}] + messages

            payload = {
                "model": self.model,
                "messages": messages,
                "options": {
                    "temperature": temperature
                },
                "stream": True
            }

            async with self.client.stream(
                "POST",
                f"{self.base_url}/api/chat",
                json=payload
            ) as response:
                response.raise_for_status()
                # Ollama streams newline-delimited JSON objects; yield each content chunk
                async for line in response.aiter_lines():
                    if line:
                        try:
                            data = json.loads(line)
                            if "message" in data and "content" in data["message"]:
                                yield data["message"]["content"]
                        except json.JSONDecodeError:
                            continue

        except httpx.RequestError as e:
            logger.error(f"Request error streaming from Ollama: {e}")
            raise Exception(f"Failed to stream from local LLM: {e}")

    async def generate_embedding(self, text: str) -> List[float]:
        """Generate embeddings using Ollama (if supported by model)"""
        try:
            payload = {
                "model": "nomic-embed-text",  # Use embedding-specific model
                "prompt": text
            }

            response = await self.client.post(
                f"{self.base_url}/api/embeddings",
                json=payload
            )
            response.raise_for_status()

            result = response.json()
            return result.get("embedding", [])

        except httpx.RequestError as e:
            logger.error(f"Request error getting embeddings from Ollama: {e}")
            return []
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error getting embeddings from Ollama: {e}")
            return []

    async def check_health(self) -> bool:
        """Check if the Ollama service is available"""
        try:
            response = await self.client.get(f"{self.base_url}/api/tags")
            return response.status_code == 200
        except Exception:
            return False

    async def list_models(self) -> List[str]:
        """List available models in Ollama"""
        try:
            response = await self.client.get(f"{self.base_url}/api/tags")
            response.raise_for_status()

            result = response.json()
            models = result.get("models", [])
            return [model["name"] for model in models]

        except httpx.RequestError as e:
            logger.error(f"Request error listing models from Ollama: {e}")
            return []

    async def close(self):
        """Close the HTTP client"""
        await self.client.aclose()


# Global instance
ollama_client = OllamaClient()
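

# Example usage: a minimal smoke-test sketch (not part of the app's request flow)
# showing how this client might be driven from an async context. It assumes an
# Ollama server is reachable at settings.OLLAMA_BASE_URL, that the configured chat
# model has been pulled, and that the `app` package is importable when this file is
# run directly.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # Only talk to the server if it responds to /api/tags
        if await ollama_client.check_health():
            print("Available models:", await ollama_client.list_models())
            reply = await ollama_client.chat(
                messages=[{"role": "user", "content": "Say hello in one sentence."}],
                system_prompt="You are a concise assistant.",
            )
            print(reply)
        await ollama_client.close()

    asyncio.run(_demo())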