model_registry.py
- Path: /var/www/scripts/pipeline/model_registry.py
- Namespace: pipeline
- Lines: 189 | Size: 5,883 bytes
- Modified: 2025-12-25 09:30:34 | Scanned: 2025-12-31 10:22:15
Code Hygiene Score: 100
- Dependencies: 100 (25%)
- LOC: 100 (20%)
- Methods: 100 (20%)
- Secrets: 100 (15%)
- Classes: 100 (10%)
- Magic Numbers: 100 (10%)
No issues found.
Dependencies (3)
- use typing.Optional
- use mysql.connector
- use config.get_db_password
Classes (1)
- ModelRegistry (class, line 15)
Functions (3)
- get_chat_model() (line 171)
- get_embedding_model() (line 181)
- get_model_id_for_ollama() (line 186)
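For orientation, a minimal usage sketch of these three module-level helpers; the import path pipeline.model_registry is an assumption taken from the Namespace field in the scan header, not confirmed by the source:

# Hypothetical caller; "pipeline.model_registry" as import path is assumed
# from the scan metadata above.
from pipeline.model_registry import (
    get_chat_model,
    get_embedding_model,
    get_model_id_for_ollama,
)

chat_key = get_chat_model()                    # full key, e.g. "ollama:mistral:latest"
embed_key = get_embedding_model()              # e.g. "ollama:mxbai-embed-large:latest"
ollama_id = get_model_id_for_ollama(chat_key)  # bare Ollama id, e.g. "mistral:latest"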
Code
"""
Model Registry - Single Source of Truth for AI Models.
Reads available models from ki_dev.ai_models database table.
Synchronizes with PHP ModelRegistry for consistent model availability.
"""
from typing import Optional
import mysql.connector
from config import get_db_password
class ModelRegistry:
    """Central registry for all AI models."""

    _cache: list | None = None
    _instance: Optional["ModelRegistry"] = None

    def __init__(self):
        self._conn = None

    @classmethod
    def get_instance(cls) -> "ModelRegistry":
        """Get singleton instance."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def _get_connection(self):
        """Get or create database connection."""
        if self._conn is None or not self._conn.is_connected():
            self._conn = mysql.connector.connect(
                host="localhost",
                user="root",
                password=get_db_password(),
                database="ki_dev",
                autocommit=True,
            )
        return self._conn
    @classmethod
    def clear_cache(cls):
        """Clear cached models."""
        cls._cache = None

    def get_chat_models(self) -> dict:
        """Get all available chat models.

        Returns:
            dict: {full_key: display_name}
        """
        return self._get_models(is_chat=True)

    def get_vision_models(self) -> dict:
        """Get all vision-capable models.

        Returns:
            dict: {full_key: display_name}
        """
        return self._get_models(is_vision=True)

    def get_embedding_models(self) -> dict:
        """Get all embedding models.

        Returns:
            dict: {full_key: display_name}
        """
        return self._get_models(is_embedding=True)

    def _get_models(
        self,
        is_chat: bool | None = None,
        is_vision: bool | None = None,
        is_embedding: bool | None = None,
        provider: str | None = None,
    ) -> dict:
        """Get models with optional filters."""
        all_models = self._load_all_models()
        result = {}
        for model in all_models:
            if not model["is_available"]:
                continue
            if is_chat is not None and bool(model["is_chat"]) != is_chat:
                continue
            if is_vision is not None and bool(model["is_vision"]) != is_vision:
                continue
            if is_embedding is not None and bool(model["is_embedding"]) != is_embedding:
                continue
            if provider is not None and model["provider"] != provider:
                continue
            result[model["full_key"]] = model["display_name"]
        return result
    def get_model(self, full_key: str) -> dict | None:
        """Get a single model by full_key."""
        all_models = self._load_all_models()
        for model in all_models:
            if model["full_key"] == full_key:
                return model
        return None

    def get_label(self, full_key: str) -> str:
        """Get display label for a model."""
        model = self.get_model(full_key)
        return model["display_name"] if model else full_key

    def is_valid(self, full_key: str) -> bool:
        """Check if model exists and is available."""
        model = self.get_model(full_key)
        return model is not None and model["is_available"]

    def is_local(self, full_key: str) -> bool:
        """Check if model is local (Ollama)."""
        return full_key.startswith("ollama:")

    def get_default_chat_model(self) -> str:
        """Get default chat model (first available by priority)."""
        chat_models = self.get_chat_models()
        if chat_models:
            return next(iter(chat_models.keys()))
        return "ollama:mistral:latest"

    def get_default_embedding_model(self) -> str:
        """Get default embedding model."""
        embed_models = self.get_embedding_models()
        if embed_models:
            return next(iter(embed_models.keys()))
        return "ollama:mxbai-embed-large:latest"

    def get_ollama_model_id(self, full_key: str) -> str:
        """Extract Ollama model ID from full_key.

        Example: 'ollama:gemma3:27b-it-qat' -> 'gemma3:27b-it-qat'
        """
        if full_key.startswith("ollama:"):
            return full_key[7:]  # Remove 'ollama:' prefix
        return full_key
    def _load_all_models(self) -> list:
        """Load all models from database (with caching)."""
        if ModelRegistry._cache is not None:
            return ModelRegistry._cache
        conn = self._get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute(
            """
            SELECT id, provider, model_id, display_name, full_key,
                   is_available, is_chat, is_embedding, is_vision,
                   context_length, parameters, priority
            FROM ai_models
            WHERE is_available = 1
            ORDER BY priority ASC
            """
        )
        ModelRegistry._cache = cursor.fetchall()
        cursor.close()
        return ModelRegistry._cache
# Convenience functions for backward compatibility


def get_chat_model() -> str:
    """Get the default chat model from registry."""
    registry = ModelRegistry.get_instance()
    # Prefer local Ollama models for pipeline
    ollama_models = registry._get_models(is_chat=True, provider="ollama")
    if ollama_models:
        return next(iter(ollama_models.keys()))
    return registry.get_default_chat_model()


def get_embedding_model() -> str:
    """Get the default embedding model from registry."""
    return ModelRegistry.get_instance().get_default_embedding_model()


def get_model_id_for_ollama(full_key: str) -> str:
    """Extract Ollama model ID from full_key."""
    return ModelRegistry.get_instance().get_ollama_model_id(full_key)
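A minimal end-to-end sketch of the class API, assuming a reachable ki_dev database whose ai_models table contains the columns selected in _load_all_models(); the flat import path is an assumption for illustration:

from model_registry import ModelRegistry  # import path assumed; see scan metadata above

registry = ModelRegistry.get_instance()

# List available chat models as {full_key: display_name}.
for full_key, label in registry.get_chat_models().items():
    print(full_key, "->", label)

# Validate a key before use and resolve it for Ollama.
key = registry.get_default_chat_model()
if registry.is_valid(key) and registry.is_local(key):
    model_id = registry.get_ollama_model_id(key)  # e.g. "mistral:latest"

# After rows in ai_models change, drop the class-level cache so the
# next lookup re-reads the table.
ModelRegistry.clear_cache()

Note that because _cache is class-level and _load_all_models() already filters on is_available = 1, clear_cache() is the only way for a running process to pick up availability changes.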