"""
ZhenQuant Python SDK v1.0.0
Transitional downloadable SDK reference aligned to the current public contract.

Install:
    pip install requests

Usage:
    from zhenquant_sdk import ZhenQuantClient

    client = ZhenQuantClient(api_key="zq_live_xxx")
    result = client.factors.calculate(
        factor_code="RSI",
        symbols=["AAPL"],
        timeframe="1d",
        period=14,
    )
"""

from typing import Any, Dict, List, Optional, Union

import requests


class ZhenQuantError(Exception):
    """Common base class for the API and transport errors raised by this SDK."""

    pass


class AuthenticationError(ZhenQuantError):
    """Raised when the API rejects the request's credentials (HTTP 401/403)."""

    pass


class RateLimitError(ZhenQuantError):
    """Raised when the API answers HTTP 429 (too many requests).

    Attributes:
        retry_after: Suggested wait in seconds taken from the ``Retry-After``
            header, or ``None`` when the header was absent or not an integer.
    """

    def __init__(self, message: str, retry_after: Optional[int] = None):
        self.retry_after = retry_after
        super().__init__(message)


class ValidationError(ZhenQuantError):
    """Raised when request parameters are rejected, locally or by the API."""

    pass


class FactorsAPI:
    """Factor execution wrapper for the unified research endpoint.

    Every method ends up in :meth:`calculate`, which builds the JSON body for
    ``POST /v1/factors/calculate`` and sends it through the owning client's
    ``_request`` method.
    """

    def __init__(self, client: "ZhenQuantClient"):
        # Back-reference: all requests go through the client's session,
        # base URL and timeout.
        self._client = client

    @staticmethod
    def _normalize_symbols(symbols: Union[str, List[str]]) -> List[str]:
        """Return *symbols* as a list of upper-cased tickers.

        A bare string is treated as one symbol. Without this guard a string
        such as ``"AAPL"`` is iterable and would be split into characters.
        """
        if isinstance(symbols, str):
            symbols = [symbols]
        return [symbol.upper() for symbol in symbols]

    def calculate(
        self,
        factor_code: str,
        symbols: Union[str, List[str]],
        timeframe: str = "1d",
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        **params: Any,
    ) -> Dict[str, Any]:
        """Calculate a factor for one or more symbols.

        Args:
            factor_code: Factor identifier (e.g. ``"RSI"``); upper-cased before sending.
            symbols: List of tickers, or a single ticker string; upper-cased before sending.
            timeframe: Bar timeframe, sent as-is (default ``"1d"``).
            start_date: Optional start date; omitted from the payload when falsy.
            end_date: Optional end date; omitted from the payload when falsy.
            **params: Factor-specific parameters, sent under ``"params"``.

        Returns:
            Whatever the client's ``_request`` returns for the call.

        Raises:
            ValidationError: If ``factor_code`` is empty or no symbols are given.
        """
        if not factor_code:
            raise ValidationError("factor_code is required")
        symbol_list = self._normalize_symbols(symbols) if symbols else []
        if not symbol_list:
            raise ValidationError("symbols must contain at least one item")

        payload: Dict[str, Any] = {
            "factor": factor_code.upper(),
            "symbols": symbol_list,
            "params": params,
            "timeframe": timeframe,
        }

        # Only include the date bounds that were actually provided.
        if start_date:
            payload["start_date"] = start_date
        if end_date:
            payload["end_date"] = end_date

        return self._client._request(
            method="POST",
            endpoint="/v1/factors/calculate",
            data=payload,
        )

    def ma(
        self,
        symbols: Union[str, List[str]],
        period: int = 20,
        timeframe: str = "1d",
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Shortcut for ``calculate(factor_code="MA", period=period, ...)``."""
        return self.calculate(
            factor_code="MA",
            symbols=symbols,
            timeframe=timeframe,
            start_date=start_date,
            end_date=end_date,
            period=period,
        )

    def rsi(
        self,
        symbols: Union[str, List[str]],
        period: int = 14,
        timeframe: str = "1d",
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Shortcut for ``calculate(factor_code="RSI", period=period, ...)``."""
        return self.calculate(
            factor_code="RSI",
            symbols=symbols,
            timeframe=timeframe,
            start_date=start_date,
            end_date=end_date,
            period=period,
        )

    def macd(
        self,
        symbols: Union[str, List[str]],
        fast_period: int = 12,
        slow_period: int = 26,
        signal_period: int = 9,
        timeframe: str = "1d",
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Shortcut for ``calculate(factor_code="MACD", ...)`` with its three periods."""
        return self.calculate(
            factor_code="MACD",
            symbols=symbols,
            timeframe=timeframe,
            start_date=start_date,
            end_date=end_date,
            fast_period=fast_period,
            slow_period=slow_period,
            signal_period=signal_period,
        )


class ZhenQuantClient:
    """Thin client for the public factor contract.

    Wraps a ``requests.Session`` pre-configured with the API key header and
    exposes factor calculations through ``self.factors`` plus convenience
    wrappers. Usable as a context manager; the session is closed on exit.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://www.zhenquant.hk",
        timeout: int = 30,
    ):
        """Create a client.

        Args:
            api_key: API key; must start with ``zq_live_``.
            base_url: API origin; any trailing slash is stripped.
            timeout: Per-request timeout in seconds, passed to ``requests``.

        Raises:
            ValueError: If ``api_key`` is empty or lacks the ``zq_live_`` prefix.
        """
        if not api_key or not api_key.startswith("zq_live_"):
            raise ValueError("Invalid API key format. Must start with 'zq_live_'")

        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.session = requests.Session()
        self.session.headers.update(
            {
                "X-API-Key": self.api_key,
                "Content-Type": "application/json",
                "User-Agent": "ZhenQuant-Python-SDK/1.0.0",
            }
        )
        self.factors = FactorsAPI(self)

    @staticmethod
    def _is_json_response(response: Any) -> bool:
        """Return True when the response's Content-Type declares a JSON body."""
        content_type = response.headers.get("content-type", "").lower()
        # Media types are case-insensitive; also accept structured-syntax
        # suffixes such as application/problem+json.
        return "application/json" in content_type or "+json" in content_type

    def _request(
        self,
        method: str,
        endpoint: str,
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Send a request and map the HTTP outcome onto SDK exceptions.

        Args:
            method: HTTP method.
            endpoint: Path appended to ``base_url``.
            data: Optional JSON body.
            params: Optional query-string parameters.

        Returns:
            The decoded JSON body, or the raw text for non-JSON responses.

        Raises:
            AuthenticationError: On HTTP 401/403.
            RateLimitError: On HTTP 429.
            ValidationError: On HTTP 400/422.
            ZhenQuantError: On timeouts, transport failures, malformed JSON
                in a successful response, 5xx, or any other non-2xx status.
        """
        url = f"{self.base_url}{endpoint}"

        try:
            response = self.session.request(
                method=method,
                url=url,
                json=data,
                params=params,
                timeout=self.timeout,
            )
        except requests.Timeout as exc:
            raise ZhenQuantError(f"Request timeout after {self.timeout} seconds") from exc
        except requests.RequestException as exc:
            raise ZhenQuantError(f"Request failed: {exc}") from exc

        if response.status_code in (401, 403):
            raise AuthenticationError("Invalid API key or authentication failed")

        if response.status_code == 429:
            retry_after = response.headers.get("Retry-After")
            raise RateLimitError(
                "Rate limit exceeded. Please try again later.",
                retry_after=int(retry_after) if retry_after and retry_after.isdigit() else None,
            )

        payload: Any
        if self._is_json_response(response):
            try:
                payload = response.json()
            except ValueError as exc:
                # The body claims to be JSON but is not (empty or truncated).
                if response.ok:
                    raise ZhenQuantError(
                        f"Invalid JSON in response (HTTP {response.status_code})"
                    ) from exc
                # On error statuses, keep the raw text so the status-specific
                # error below is still raised.
                payload = response.text
        else:
            payload = response.text

        # 422 Unprocessable Entity is also a request-validation failure.
        if response.status_code in (400, 422):
            if isinstance(payload, dict):
                raise ValidationError(payload.get("detail") or payload.get("message") or "Invalid parameters")
            raise ValidationError(str(payload))

        if response.status_code >= 500:
            raise ZhenQuantError(f"Server error: {response.status_code}")

        if not response.ok:
            raise ZhenQuantError(f"HTTP {response.status_code}: {payload}")

        return payload

    def calculate_ma(
        self,
        symbols: List[str],
        period: int = 20,
        timeframe: str = "1d",
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Convenience wrapper around ``self.factors.ma``."""
        return self.factors.ma(
            symbols=symbols,
            period=period,
            timeframe=timeframe,
            start_date=start_date,
            end_date=end_date,
        )

    def calculate_rsi(
        self,
        symbols: List[str],
        period: int = 14,
        timeframe: str = "1d",
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Convenience wrapper around ``self.factors.rsi``."""
        return self.factors.rsi(
            symbols=symbols,
            period=period,
            timeframe=timeframe,
            start_date=start_date,
            end_date=end_date,
        )

    def calculate_macd(
        self,
        symbols: List[str],
        fast_period: int = 12,
        slow_period: int = 26,
        signal_period: int = 9,
        timeframe: str = "1d",
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Convenience wrapper around ``self.factors.macd``."""
        return self.factors.macd(
            symbols=symbols,
            fast_period=fast_period,
            slow_period=slow_period,
            signal_period=signal_period,
            timeframe=timeframe,
            start_date=start_date,
            end_date=end_date,
        )

    def calculate_batch(
        self,
        symbols: List[str],
        factor_requests: List[Dict[str, Any]],
        timeframe: str = "1d",
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> Dict[str, Dict[str, Any]]:
        """Run several factor calculations for the same symbols, one call each.

        Each request dict names its factor via ``"factor_code"`` (or
        ``"code"``), may carry a ``"params"`` dict, and may also give
        ``period``/``fast_period``/``slow_period``/``signal_period`` at the top
        level. These are merged into params unless ``params`` already has them.

        Returns:
            A dict keyed by upper-cased factor code. A failed request maps to
            ``{"error": message}``, and one with no code is keyed ``"unknown"``.
            Repeated keys, such as two MA windows, get ``"#2"``, ``"#3"``...
            suffixes in request order, so no result is silently overwritten.
        """
        results: Dict[str, Dict[str, Any]] = {}

        def unique_key(base: str) -> str:
            # First occurrence keeps the plain key; later ones get a suffix.
            key, index = base, 2
            while key in results:
                key = f"{base}#{index}"
                index += 1
            return key

        for request in factor_requests:
            factor_code = str(request.get("factor_code") or request.get("code") or "").upper()
            params = dict(request.get("params") or {})

            for name in ("period", "fast_period", "slow_period", "signal_period"):
                if name in request and name not in params:
                    params[name] = request[name]

            if not factor_code:
                results[unique_key("unknown")] = {"error": "factor_code is required"}
                continue

            key = unique_key(factor_code)
            try:
                results[key] = self.factors.calculate(
                    factor_code=factor_code,
                    symbols=symbols,
                    timeframe=timeframe,
                    start_date=start_date,
                    end_date=end_date,
                    **params,
                )
            except ZhenQuantError as exc:
                results[key] = {"error": str(exc)}

        return results

    def health_check(self) -> Dict[str, Any]:
        """GET ``/api/health`` and return the decoded response."""
        return self._request(method="GET", endpoint="/api/health")

    def close(self):
        """Close the underlying HTTP session."""
        self.session.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always release the session; exceptions are not suppressed.
        self.close()


if __name__ == "__main__":
    # Demo only: replace the placeholder with a real key before running.
    demo_key = "zq_live_your_api_key_here"
    macd_request = {
        "factor_code": "MACD",
        "symbols": ["AAPL", "TSLA", "00700.HK"],
        "timeframe": "1d",
        "fast_period": 12,
        "slow_period": 26,
        "signal_period": 9,
    }

    with ZhenQuantClient(api_key=demo_key) as demo_client:
        print(f"API status: {demo_client.health_check()}")

        macd_result = demo_client.factors.calculate(**macd_request)
        print("Unified factor result:")
        print(macd_result)
