#!/usr/bin/env python3
"""
Copyright 2025 BlueCat Networks (USA) Inc. and its affiliates and licensors.
All Rights Reserved.

This file is part of mmWSClient, a Python module that provides a simple
wrapper for communicating with the Micetro Web Service.

Licensed under the MIT License. See LICENSE file in the project root
for full license information.
"""
import requests_toolbelt.sessions as sessions
from requests import Response
from urllib.parse import urlencode, quote, urljoin, urlsplit, urlunsplit
from typing import Any, Mapping, Optional
import os
import pprint
import logging

# Names exported by ``from <module> import *``.
__all__ = [
    "Client",
    "ClientError",
]
# Module logger; Client(verbose=True) raises its level to DEBUG.
logger = logging.getLogger(name=__name__)


class ClientError(Exception):
    """Error raised by this library, e.g. when the Micetro server reports a failure."""


def _get_env_or_value(
    env_key: str, value: Optional[str], required: bool = False
) -> Optional[str]:
    """Return *value*, falling back to the ``env_key`` environment variable.

    An explicitly supplied, non-blank *value* takes precedence. The
    environment variable is consulted only when *value* is ``None`` or a
    blank string. String results are stripped of surrounding whitespace.

    Args:
        env_key: Name of the environment variable used as the fallback.
        value: Explicitly supplied value (may be ``None``).
        required: When true, a missing/empty result is an error.

    Returns:
        The resolved (stripped) value, or ``None``/``""`` when nothing is set
        and *required* is false.

    Raises:
        ValueError: if *required* is true and the resolved value is ``None``
            or empty.
    """
    result = value
    if result is None or (isinstance(result, str) and not result.strip()):
        # Environment variables are fallbacks: only read them when the caller
        # did not supply a usable value.
        result = os.environ.get(env_key, result)
    if isinstance(result, str):
        result = result.strip()
    if required and not result:  # None or ""
        raise ValueError(f"Missing Client argument: {env_key}")
    return result


def _ensure_trailing_slash(p: str) -> str:
    """Return *p*, appending a single "/" unless it already ends with one."""
    if p.endswith("/"):
        return p
    return f"{p}/"


def _encode_query(params: Mapping[str, Any]) -> str:
    """Encode *params* as an RFC 3986 style query string.

    Spaces become ``%20`` (not ``+``), ``~`` is left unescaped, and sequence
    values expand into repeated keys (``a=1&a=2``).
    """
    return urlencode(
        query=params,
        doseq=True,
        safe="~",
        quote_via=quote,
    )


def _merge_query(url: str, extra_qs: str) -> str:
    """Append the already-encoded query string *extra_qs* to *url*'s query.

    Any existing query is kept in front, joined with ``&``; the fragment is
    preserved. An empty *extra_qs* returns *url* untouched.
    """
    if extra_qs:
        scheme, netloc, path, query, fragment = urlsplit(url)
        combined = f"{query}&{extra_qs}" if query else extra_qs
        return urlunsplit((scheme, netloc, path, combined, fragment))
    return url


def _join_path(path: str, *segments) -> str:
    """Join *path* and *segments* with "/", percent-encoding each piece.

    ``None`` and Mapping arguments are skipped (mappings are meant for query
    params or the request body). Leading/trailing slashes are stripped from
    every piece and empty pieces are dropped, so the result never starts with
    "/". Inside a piece, "/" and "~" are kept; everything else unsafe
    (e.g. spaces) is percent-encoded.
    """
    cleaned = (
        str(seg).strip("/")
        for seg in (path, *segments)
        if seg is not None and not isinstance(seg, Mapping)
    )
    return "/".join(quote(piece, safe="/~") for piece in cleaned if piece)


class Client:
    """Client for the Micetro Web Service REST API (``<server>/api/v2/``).

    Besides the explicit ``request``/``get``/``post``/``delete`` methods, any
    unknown public attribute resolves to a callable that POSTs to
    ``command/<name>`` (see ``__getattr__``).
    """

    def __init__(
        self,
        server: Optional[str] = None,
        central: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        token: Optional[str] = None,
        verbose: bool = False,
    ):
        """
        Client for interacting with the Micetro Web Service API.

        This class provides basic REST methods: `get()`, `post()`, and `delete()`.
        It also overrides `__getattr__` to dynamically map unknown method calls to
        remote commands, with arguments treated as request data.

        Parameters:
            server (str): URL to the Micetro Web Service. If the value contains
                no "/" (a bare host or host:port), "/mmws/" is appended; if it
                has no scheme, "http://" is prepended.

            central (str): Optional Micetro Central server address. Typically not required
                if Central is running on the same server as the web service or its
                location is already configured.

            username (str): Authentication username.
            password (str): Authentication password.
            token (str): A session ID or token for authentication.
            verbose (bool): Set this module's logger to DEBUG so requests and
                responses are logged.

        Environment Variable Fallbacks:
            MM_SERVER, MM_CENTRAL, MM_USERNAME, MM_PASSWORD, MM_TOKEN, MM_VERBOSE
            (MM_VERBOSE enables verbose mode when set to "true", "1" or "yes",
            case-insensitively).

        Raises:
            ValueError: if no server is configured, if neither a token nor a
                username/password pair is available, or if the login request
                is rejected.
            ClientError: if the login response carries no usable session id.

        Returns:
            A `Client` object that can be used to interact with Micetro and perform various
            operations. For a list of available functions, refer to:

            - https://docs.menandmice.com/en/latest/guides/user-manual/rest_api/
            - https://docs.menandmice.com/en/latest/guides/user-manual/json_rpc/
        """
        self.server = _get_env_or_value("MM_SERVER", server, required=True)
        self.central = _get_env_or_value("MM_CENTRAL", central)
        self._username = _get_env_or_value("MM_USERNAME", username)
        self._password = _get_env_or_value("MM_PASSWORD", password)
        self.token = _get_env_or_value("MM_TOKEN", token)
        verbose_env = (_get_env_or_value("MM_VERBOSE", "") or "").lower()
        self.verbose = verbose_env in ("true", "1", "yes") or verbose

        if self.verbose:
            logger.setLevel(logging.DEBUG)

        if not (self.token or (self._username and self._password)):
            raise ValueError("Please provide either token or username/password")

        # Host only (no "/" at all) -> assume the service lives under /mmws/.
        if "/" not in self.server:
            self.server += "/mmws/"
        # No scheme -> assume http. Check self.server rather than the `server`
        # argument, which is None when the value came from MM_SERVER.
        if "://" not in self.server:
            self.server = "http://" + self.server
        self.server = _ensure_trailing_slash(self.server)
        self._create_session()

    def _create_session(self):
        """Create the HTTP session and make sure it carries a bearer token.

        The session's base URL is ``<server>api/v2/``. When no token was
        supplied, log in through ``micetro/sessions`` and use the returned
        session id as the bearer token.

        Raises:
            ValueError: if the login request is rejected.
            ClientError: if the login response has no usable session id.
        """
        # Base at .../mmws/api/v2/
        self.session = sessions.BaseUrlSession(base_url=urljoin(self.server, "api/v2/"))

        # Default query params (appended to every request): target a specific
        # Central server when one was configured.
        self.session.params = {"server": self.central} if self.central else {}

        # Default headers
        headers = {"User-Agent": "mmWSClient/1.1.0", "Accept": "application/json"}

        if not self.token:
            # We need a usable token; a valid session id works as one.
            self.token = self._login(headers)

        headers["Authorization"] = "Bearer " + self.token
        self.session.headers.update(headers)

    def _login(self, headers: Mapping[str, str]) -> str:
        """Log in with the stored username/password and return the session id.

        Raises:
            ValueError: if the server rejects the login request.
            ClientError: if the response is malformed or the session is empty.
        """
        data = {"loginName": self._username, "password": self._password}
        if self.central:
            data["server"] = self.central
        resp = self.session.post("micetro/sessions", headers=headers, json=data)
        if not resp.ok:
            raise ValueError("Invalid username or password")
        try:
            token = resp.json()["result"]["session"]
        except (ValueError, KeyError, TypeError) as err:
            raise ClientError(
                "micetro/sessions returned an unexpected response"
            ) from err
        if not token:
            raise ClientError("micetro/sessions returned an empty session")
        return token

    # ---------- General request ----------
    def request(
        self,
        method: str,
        path: str,
        *,
        params: Optional[Mapping[str, Any]] = None,
        headers: Optional[Mapping[str, str]] = None,
        json: Any = None,
        data: Any = None,
        **kwargs: Any,
    ):
        """Send an HTTP request relative to the API base URL and decode the reply.

        Args:
            method: HTTP method, e.g. "GET".
            path: Path relative to the API base URL (or an absolute URL).
            params: Extra query parameters, encoded with spaces as ``%20``.
            headers: Per-request headers merged by the session.
            json: JSON body.
            data: Raw body.
            **kwargs: Passed through to ``session.request`` unchanged.

        Returns:
            ``{}`` for an empty body; the ``result`` member when the decoded
            JSON is an object that has one; otherwise the decoded JSON as-is.

        Raises:
            ClientError: if the response is an error (``resp.ok`` is false).
        """
        url = self.session.create_url(path)  # works with relative or absolute
        if params:
            # Encode ourselves so spaces become %20 rather than "+".
            url = _merge_query(url, _encode_query(params))
        self._log_request(method, url, headers, json, data)
        resp = self.session.request(
            method, url, headers=headers, json=json, data=data, **kwargs
        )
        self._log_response(resp)

        if not resp.ok:
            self._handle_error(resp)

        # Some API calls return nothing, no json.
        body = resp.content
        if not body or not body.strip():
            return {}

        payload = resp.json()

        # Most responses wrap their data in a single "result" member; unwrap
        # it. Others (e.g. GetReportFile()) return raw JSON, which is returned
        # unchanged. Only dicts can carry "result": a membership test on a
        # string payload would be a substring match.
        if isinstance(payload, dict) and "result" in payload:
            return payload["result"]
        return payload

    # ---------- Public REST methods ----------
    @staticmethod
    def _collect_mappings(segments) -> dict:
        """Merge every Mapping in *segments*, left to right, into a new dict."""
        merged: dict = {}
        for seg in segments:
            if isinstance(seg, Mapping):
                merged.update(seg)
        return merged

    def get(self, path: str, *segments, **kwargs):
        """
        GET("users", 2, limit=10, offset=1)
        -> GET users/2?limit=10&offset=1
        Dict positional args are folded into params too (kwargs win on
        duplicate keys).

        Returns whatever `request()` returns.
        """
        rel = _join_path(path, *segments)
        params = self._collect_mappings(segments)
        params.update(kwargs)
        return self.request("GET", rel, params=params)

    def post(
        self,
        path: str,
        *segments,
        params: Optional[Mapping[str, Any]] = None,
        headers: Optional[Mapping[str, str]] = None,
        **kwargs,
    ):
        """
        POST("users", 2, "key", {"description": "here"}, name="my key")
        -> POST users/2/key
            body: {"name": "my key", "description": "here"}
        Dict positional args go to the JSON body; kwargs too (kwargs win on
        duplicate keys). `params` becomes the query string.

        Returns whatever `request()` returns.
        """
        rel = _join_path(path, *segments)
        body = self._collect_mappings(segments)
        body.update(kwargs)

        return self.request(
            "POST",
            rel,
            params=params,  # optional querystring on POST
            headers=headers,
            json=body or None,  # no JSON body at all when empty
        )

    def delete(self, path: str, *segments, **kwargs):
        """
        DELETE("users", 2, saveComment="a save comment")
        -> DELETE users/2?saveComment=a%20save%20comment
        Dict positional args are folded into params too (kwargs win on
        duplicate keys).

        Returns whatever `request()` returns.
        """
        rel = _join_path(path, *segments)
        params = self._collect_mappings(segments)
        params.update(kwargs)
        return self.request("DELETE", rel, params=params)

    # ---------- Dynamic mapping ----------
    def __getattr__(self, method_name: str):
        """Map an unknown public attribute to ``POST command/<method_name>``.

        The returned callable builds the JSON body from its keyword arguments,
        then merges any positional mappings on top (positional mappings win on
        duplicate keys). Non-mapping positional arguments are ignored.

        Raises:
            AttributeError: for names starting with "_", which are never
                remote commands.
        """
        # Private/dunder lookups (IPython's _repr_*_ probes, copy/pickle
        # hooks, hasattr checks) must not turn into remote calls.
        if method_name.startswith("_"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {method_name!r}"
            )

        def dynamic_method(*args, **kwargs):
            body = dict(kwargs)
            for arg in args:
                if isinstance(arg, Mapping):
                    body.update(arg)
            # Pass the body positionally so keys like "params" or "headers"
            # end up in the JSON body instead of post()'s keyword parameters.
            return self.post(f"command/{method_name}", body)

        dynamic_method.__name__ = method_name
        return dynamic_method

    def _log_request(self, method, url, headers, json, data):
        """Log the request line plus any JSON/data body at DEBUG level.

        *headers* is accepted but not logged. bytes/bytearray data is shown
        as a ``<bytes>`` placeholder.
        """
        if not logger.isEnabledFor(logging.DEBUG):
            return  # skip pformat of possibly large bodies
        extras = {}
        if json is not None:
            extras["json"] = json
        if data is not None:
            extras["data"] = "<bytes>" if isinstance(data, (bytes, bytearray)) else data
        logger.debug(
            f"Request: {method} {url}"
            + (f"\n > {pprint.pformat(extras)}" if extras else "")
        )

    def _log_response(self, response):
        """Log the status code and body (text for JSON/text types, else size)."""
        if not logger.isEnabledFor(logging.DEBUG):
            return  # avoid decoding the body when nothing would be logged
        ctype = response.headers.get("Content-Type", "")
        if ("json" in ctype or "text" in ctype) and response.text:
            body = f" < {response.text}\n"
        elif response.content:
            body = f" < <{len(response.content)} bytes>\n"
        else:
            body = ""
        logger.debug(f"Response: {response.status_code}\n{body}")

    def _handle_error(self, response: Response):
        """Raise a ClientError describing a failed response.

        Uses the ``error.code``/``error.message`` members of a JSON object
        body when present. Otherwise it falls back to the HTTP status code and
        raw body text.

        Raises:
            ClientError: always.
        """
        try:
            api_error = response.json()
        except ValueError:
            # Not JSON at all (e.g. an HTML error page from a proxy).
            raise ClientError(
                f"Error {response.status_code} occurred: {response.text}"
            ) from None

        if not isinstance(api_error, dict):
            raise ClientError(f"Error {response.status_code} occurred: {response.text}")

        error_info = api_error.get("error")
        if isinstance(error_info, str):
            error_info = {"message": error_info}
        elif not isinstance(error_info, dict):
            error_info = {}
        code = error_info.get("code")
        message = error_info.get("message", "Unknown error")
        raise ClientError(f"Server responded with error {code}: {message}")
