update_tool.py

Code Hygiene Score: 88

Keine Issues gefunden.

Dependencies 10

Funktionen 2

Code

"""Update Tool - Aktualisiert Datensaetze."""

import re
import time
from datetime import datetime

from shared.constants import LOG_ENTRY_MAX_LENGTH, LOG_QUERY_MAX_LENGTH, MS_PER_SECOND

from config import Config
from infrastructure.db_connection import DatabaseConnection
from shared.domain import LogEntry
from shared.infrastructure import get_logger


def _validate_identifier(name: str) -> bool:
    """Validiert Tabellen-/Spaltennamen gegen SQL-Injection."""
    return bool(re.match(r"^[a-zA-Z0-9_]+$", name))


def _find_request_error(table, data, where, database):
    """Return the denial message for an invalid update request, or None.

    Checks run in a fixed order (table, database, empty data, empty where,
    data columns, where columns) so the first problem found is reported.
    """
    if not _validate_identifier(table):
        return "Invalid table name."
    if database not in Config.ALLOWED_DATABASES:
        return f"Database '{database}' not allowed."
    if not data:
        return "Data dict must not be empty."
    # CRITICAL: an UPDATE without WHERE would modify every row of the table.
    if not where:
        return "WHERE clause is required. UPDATE without WHERE is forbidden."
    for col in data:
        if not _validate_identifier(col):
            return f"Invalid column name in data: {col}"
    for col in where:
        if not _validate_identifier(col):
            return f"Invalid column name in where: {col}"
    return None


def _build_update_query(table, data, where):
    """Build a parameterised UPDATE statement and its bind values.

    Identifiers must already have been validated (they are only wrapped in
    backticks); all values are passed as bind parameters. A ``None`` value in
    *where* is rendered as ``IS NULL``, because ``col = NULL`` never matches
    any row in SQL.

    Returns:
        (query, values) where values are the SET values followed by the
        non-NULL WHERE values, in placeholder order.
    """
    set_clause = ", ".join(f"`{col}` = %s" for col in data)

    where_parts = []
    where_values = []
    for col, value in where.items():
        if value is None:
            where_parts.append(f"`{col}` IS NULL")
        else:
            where_parts.append(f"`{col}` = %s")
            where_values.append(value)
    where_clause = " AND ".join(where_parts)

    query = f"UPDATE `{table}` SET {set_clause} WHERE {where_clause}"
    values = tuple(data.values()) + tuple(where_values)
    return query, values


def register_update_tool(mcp) -> None:
    """Register the ``db_update`` tool on the given MCP server object."""
    logger = get_logger("mcp-db", Config)

    def _log_best_effort(request, status, duration_ms, **extra):
        """Write a LogEntry for this tool; logging must never fail the call."""
        try:
            logger.log(
                LogEntry(
                    timestamp=datetime.now(),
                    client_name="mcp-db",
                    tool_name="db_update",
                    request=request,
                    status=status,
                    duration_ms=duration_ms,
                    **extra,
                )
            )
        except Exception:
            # Deliberately best-effort: a broken log sink must not turn an
            # already committed update (or a reported error) into a crash.
            pass

    @mcp.tool()
    def db_update(
        table: str,
        data: dict,
        where: dict,
        database: str = "ki_dev",
    ) -> dict:
        """
        Update records.

        Args:
            table: Target table
            data: Dict of column:value pairs (SET clause)
            where: Dict of column:value pairs (WHERE clause) - REQUIRED!
                A None value matches rows where the column IS NULL.
            database: Target database (ki_dev or ki_content)

        Returns:
            Dict with status ("success", "denied" or "error"); on success also
            table, affected_rows and execution_ms; otherwise error (and
            execution_ms for database errors).
        """
        # perf_counter is monotonic, unlike time.time(), so durations cannot
        # go negative or jump when the wall clock is adjusted.
        start = time.perf_counter()

        denial = _find_request_error(table, data, where, database)
        if denial is not None:
            return {"status": "denied", "error": denial}

        try:
            query, values = _build_update_query(table, data, where)
            with DatabaseConnection.get_connection(database) as conn:
                cursor = conn.cursor(buffered=True)
                try:
                    cursor.execute(query, values)
                    affected_rows = cursor.rowcount
                    conn.commit()
                finally:
                    # Release the cursor even when execute/commit raises.
                    cursor.close()
        except Exception as e:
            duration = int((time.perf_counter() - start) * MS_PER_SECOND)
            _log_best_effort(
                f"UPDATE {table}",
                "error",
                duration,
                error_message=str(e)[:LOG_ENTRY_MAX_LENGTH],
            )
            return {
                "status": "error",
                "error": str(e)[:LOG_QUERY_MAX_LENGTH],
                "execution_ms": duration,
            }

        duration = int((time.perf_counter() - start) * MS_PER_SECOND)
        # Only a placeholder is logged: column values may be sensitive.
        _log_best_effort(f"UPDATE {table} SET ... WHERE ...", "success", duration)

        return {
            "status": "success",
            "table": table,
            "affected_rows": affected_rows,
            "execution_ms": duration,
        }
← Übersicht