model_registry.py

Code Hygiene Score: 100

No issues found.

Dependencies 3

Classes 1

Functions 3

Code

"""
Model Registry - Single Source of Truth for AI Models.

Reads available models from ki_dev.ai_models database table.
Synchronizes with PHP ModelRegistry for consistent model availability.
"""

from typing import Optional

import mysql.connector

from config import get_db_password


class ModelRegistry:
    """Central registry for all AI models."""

    _cache: list | None = None
    _instance: Optional["ModelRegistry"] = None

    def __init__(self):
        self._conn = None

    @classmethod
    def get_instance(cls) -> "ModelRegistry":
        """Get singleton instance."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def _get_connection(self):
        """Get or create database connection."""
        if self._conn is None or not self._conn.is_connected():
            self._conn = mysql.connector.connect(
                host="localhost",
                user="root",
                password=get_db_password(),
                database="ki_dev",
                autocommit=True,
            )
        return self._conn

    @classmethod
    def clear_cache(cls):
        """Clear cached models."""
        cls._cache = None

    def get_chat_models(self) -> dict:
        """Get all available chat models.

        Returns:
            dict: {full_key: display_name}
        """
        return self._get_models(is_chat=True)

    def get_vision_models(self) -> dict:
        """Get all vision-capable models.

        Returns:
            dict: {full_key: display_name}
        """
        return self._get_models(is_vision=True)

    def get_embedding_models(self) -> dict:
        """Get all embedding models.

        Returns:
            dict: {full_key: display_name}
        """
        return self._get_models(is_embedding=True)

    def _get_models(
        self,
        is_chat: bool | None = None,
        is_vision: bool | None = None,
        is_embedding: bool | None = None,
        provider: str | None = None,
    ) -> dict:
        """Get models with optional filters."""
        all_models = self._load_all_models()
        result = {}

        for model in all_models:
            if not model["is_available"]:
                continue
            if is_chat is not None and bool(model["is_chat"]) != is_chat:
                continue
            if is_vision is not None and bool(model["is_vision"]) != is_vision:
                continue
            if is_embedding is not None and bool(model["is_embedding"]) != is_embedding:
                continue
            if provider is not None and model["provider"] != provider:
                continue

            result[model["full_key"]] = model["display_name"]

        return result

    def get_model(self, full_key: str) -> dict | None:
        """Get a single model by full_key."""
        all_models = self._load_all_models()

        for model in all_models:
            if model["full_key"] == full_key:
                return model

        return None

    def get_label(self, full_key: str) -> str:
        """Get display label for a model."""
        model = self.get_model(full_key)
        return model["display_name"] if model else full_key

    def is_valid(self, full_key: str) -> bool:
        """Check if model exists and is available."""
        model = self.get_model(full_key)
        return model is not None and model["is_available"]

    def is_local(self, full_key: str) -> bool:
        """Check if model is local (Ollama)."""
        return full_key.startswith("ollama:")

    def get_default_chat_model(self) -> str:
        """Get default chat model (first available by priority)."""
        chat_models = self.get_chat_models()
        if chat_models:
            return next(iter(chat_models.keys()))
        return "ollama:mistral:latest"

    def get_default_embedding_model(self) -> str:
        """Get default embedding model."""
        embed_models = self.get_embedding_models()
        if embed_models:
            return next(iter(embed_models.keys()))
        return "ollama:mxbai-embed-large:latest"

    def get_ollama_model_id(self, full_key: str) -> str:
        """Extract Ollama model ID from full_key.

        Example: 'ollama:gemma3:27b-it-qat' -> 'gemma3:27b-it-qat'
        """
        if full_key.startswith("ollama:"):
            return full_key[7:]  # Remove 'ollama:' prefix
        return full_key

    def _load_all_models(self) -> list:
        """Load all models from database (with caching)."""
        if ModelRegistry._cache is not None:
            return ModelRegistry._cache

        conn = self._get_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute(
            """
            SELECT id, provider, model_id, display_name, full_key,
                   is_available, is_chat, is_embedding, is_vision,
                   context_length, parameters, priority
            FROM ai_models
            WHERE is_available = 1
            ORDER BY priority ASC
            """
        )

        ModelRegistry._cache = cursor.fetchall()
        cursor.close()
        return ModelRegistry._cache


# Convenience functions for backward compatibility
def get_chat_model() -> str:
    """Get the default chat model from registry."""
    registry = ModelRegistry.get_instance()
    # Prefer local Ollama models for pipeline
    local_chat = registry._get_models(is_chat=True, provider="ollama")
    if not local_chat:
        return registry.get_default_chat_model()
    return next(iter(local_chat))


def get_embedding_model() -> str:
    """Get the default embedding model from registry."""
    registry = ModelRegistry.get_instance()
    return registry.get_default_embedding_model()


def get_model_id_for_ollama(full_key: str) -> str:
    """Extract Ollama model ID from full_key."""
    registry = ModelRegistry.get_instance()
    return registry.get_ollama_model_id(full_key)
← Overview Graph