#!/usr/bin/env python3
"""
Test Auth Flow - Complete OAuth authentication and Avi A2A server connection test.

This script contains all OAuth authentication components and tests the complete
auth flow including:
1. OAuth token acquisition with DCR (Dynamic Client Registration)
2. A2A server connection

All OAuth-related classes and functions are contained in this script.
"""

import asyncio
import ssl
import webbrowser
import tempfile
import os
import logging
import secrets
import base64
import hashlib
import sys
import contextlib
import argparse
from pathlib import Path
from typing import Optional, Dict, Any
from authlib.integrations.httpx_client import AsyncOAuth2Client
import httpx
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse
from starlette.routing import Route
import uvicorn
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
import datetime
import ipaddress
from uuid import uuid4
import json

from a2a.client import A2ACardResolver, A2AClient
from a2a.types import (
    AgentCard,
    MessageSendParams,
    SendMessageRequest,
    SendStreamingMessageRequest,
)
from a2a.utils.constants import (
    AGENT_CARD_WELL_KNOWN_PATH,
    EXTENDED_AGENT_CARD_PATH,
)

# Logging: INFO and above, with timestamp, logger name and level on each line.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Configuration

# Product path used when none is given on the command line.
DEFAULT_PRODUCT = "avatax"
# MCP server queried for OAuth discovery; get_server_url() appends the product.
BASE_SERVER_URL = "https://mcp.avalara.com"
# Avi A2A agent endpoint queried after authentication.
AVI_A2A_URL = "https://avi.avalara.com/a2a/avi"
# Local port of the HTTPS OAuth redirect (callback) server.
CALLBACK_PORT = 59369


def get_server_url(product: str = DEFAULT_PRODUCT) -> str:
    """
    Build the MCP server URL for a product.

    Args:
        product: Product name appended as the final path segment.

    Returns:
        ``BASE_SERVER_URL`` and ``product`` joined by a single ``/``.
    """
    return "/".join((BASE_SERVER_URL, product))


# URL for the default product, kept for backward compatibility.
SERVER_URL = get_server_url()


@contextlib.contextmanager
def suppress_stderr():
    """Context manager to suppress stderr output"""
    with open(os.devnull, 'w') as devnull:
        old_stderr = sys.stderr
        sys.stderr = devnull
        try:
            yield
        finally:
            sys.stderr = old_stderr


class LoggingTransport(httpx.AsyncHTTPTransport):
    """Async transport that logs raw OAuth traffic at DEBUG level.

    Only requests whose URL contains ``/connect/token`` or
    ``/connect/authorize`` are logged, and only when ``verbose`` is True and
    the logger is enabled for DEBUG. All requests are otherwise passed to
    ``httpx.AsyncHTTPTransport`` unchanged.
    """

    # URL fragments identifying the OAuth endpoints whose traffic is logged.
    _OAUTH_PATH_MARKERS = ('/connect/token', '/connect/authorize')

    def __init__(self, verbose: bool = False):
        """
        Args:
            verbose: When True, log request and response bodies of OAuth endpoints.
        """
        super().__init__()
        self.verbose = verbose
        self.logger = logging.getLogger(f"{__name__}.LoggingTransport")

    @staticmethod
    def _decode_body(body: bytes) -> str:
        """Return ``body`` as UTF-8 text, or ``str(body)`` if it is not valid UTF-8."""
        try:
            return body.decode('utf-8')
        except UnicodeDecodeError:
            return str(body)

    async def handle_async_request(self, request):
        """Send ``request``, logging both directions when it targets an OAuth endpoint."""
        should_log = (
            self.verbose
            and self.logger.isEnabledFor(logging.DEBUG)
            and any(marker in str(request.url) for marker in self._OAUTH_PATH_MARKERS)
        )

        if should_log:
            self.logger.debug("=== Raw OAuth Request ===")
            self.logger.debug("Method: %s", request.method)
            self.logger.debug("URL: %s", request.url)
            if request.content:
                self.logger.debug("Body: %s", self._decode_body(request.content))
            self.logger.debug("========================")

        response = await super().handle_async_request(request)

        if should_log:
            # A transport hands back an unread, streaming response: touching
            # `.content` before reading raises httpx.ResponseNotRead. Reading
            # it here caches the body, so the client still sees the content.
            await response.aread()
            self.logger.debug("=== Raw OAuth Response ===")
            self.logger.debug("Status: %s", response.status_code)
            self.logger.debug("Body: %s", self._decode_body(response.content))
            self.logger.debug("=========================")

        return response


def create_self_signed_cert():
    """Create a self-signed certificate for localhost.

    Generates a new 2048-bit RSA key (public exponent 65537) and an X.509
    certificate with subject and issuer both ``CN=localhost``. The
    certificate is valid from now for one day, carries a non-critical
    subjectAltName for ``localhost`` and ``127.0.0.1``, and is signed with
    SHA-256 by that same key.

    Returns:
        ``(private_key, cert)``
    """
    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    localhost_name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "localhost")])
    alt_names = x509.SubjectAlternativeName([
        x509.DNSName("localhost"),
        x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
    ])

    builder = x509.CertificateBuilder()
    # Self-signed: the subject is also the issuer.
    builder = builder.subject_name(localhost_name).issuer_name(localhost_name)
    builder = builder.public_key(key.public_key())
    builder = builder.serial_number(x509.random_serial_number())
    builder = builder.not_valid_before(datetime.datetime.now(datetime.timezone.utc))
    builder = builder.not_valid_after(
        datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)
    )
    builder = builder.add_extension(alt_names, critical=False)

    return key, builder.sign(key, hashes.SHA256())


class HTTPSCallbackServer:
    """HTTPS callback server for the OAuth authorization-code redirect.

    Serves ``GET /callback`` on 127.0.0.1 using a freshly generated
    self-signed certificate. It records either the ``code`` (and ``state``)
    query parameters of a successful redirect or the ``error`` of a failed one.
    """

    def __init__(self, port=CALLBACK_PORT):
        """
        Args:
            port: Local TCP port to listen on.
        """
        self.port = port
        self.auth_code = None
        # `state` echoed on the successful redirect, so callers can compare it
        # with the value they put into the authorization URL.
        self.state = None
        # `error` (plus `error_description`, when present) of a failed redirect.
        self.error = None
        self.server = None
        self.logger = logging.getLogger(f"{__name__}.HTTPSCallbackServer")

    async def callback_handler(self, request):
        """Record the redirect's authorization code, or its error, and reply in plain text."""
        query_params = dict(request.query_params)
        if 'code' in query_params:
            self.auth_code = query_params['code']
            self.state = query_params.get('state')
            return PlainTextResponse("Authorization successful! You can close this window.")
        if 'error' in query_params:
            # Error redirect (RFC 6749 section 4.1.2.1), e.g. the user denied
            # consent. Recording it lets wait_for_callback() fail immediately
            # instead of polling until its timeout.
            description = query_params.get('error_description')
            self.error = query_params['error'] + (f": {description}" if description else "")
        return PlainTextResponse("Authorization failed!", status_code=400)

    async def start_server(self):
        """Serve the callback app until ``shutdown()`` asks uvicorn to exit.

        The certificate and key are written to a temporary directory that is
        removed once ``serve()`` returns.
        """
        app = Starlette(routes=[
            Route("/callback", self.callback_handler, methods=["GET"])
        ])

        private_key, cert = create_self_signed_cert()

        with tempfile.TemporaryDirectory() as temp_dir:
            cert_file = Path(temp_dir) / "cert.pem"
            key_file = Path(temp_dir) / "key.pem"

            with open(cert_file, "wb") as f:
                f.write(cert.public_bytes(serialization.Encoding.PEM))

            with open(key_file, "wb") as f:
                f.write(private_key.private_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PrivateFormat.PKCS8,
                    encryption_algorithm=serialization.NoEncryption()
                ))

            config = uvicorn.Config(
                app,
                host="127.0.0.1",
                port=self.port,
                ssl_keyfile=str(key_file),
                ssl_certfile=str(cert_file),
                log_level="critical",  # Suppress all uvicorn logging
                access_log=False
            )
            self.server = uvicorn.Server(config)
            await self.server.serve()

    async def wait_for_callback(self, timeout=300):
        """Wait for the OAuth callback.

        Args:
            timeout: Seconds to wait before giving up.

        Returns:
            The authorization code.

        Raises:
            RuntimeError: if the redirect carried an ``error`` instead of a code.
            TimeoutError: if no code arrived within ``timeout`` seconds.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while self.auth_code is None:
            if self.error is not None:
                raise RuntimeError(f"OAuth authorization failed: {self.error}")
            if loop.time() > deadline:
                raise TimeoutError("OAuth callback timeout")
            await asyncio.sleep(0.1)
        return self.auth_code

    async def shutdown(self):
        """Ask the running uvicorn server (if any) to exit; does not wait for it."""
        if self.server:
            self.logger.debug("Shutting down callback server")
            self.server.should_exit = True


def create_insecure_http_client(
    headers: dict[str, str] | None = None,
    timeout: httpx.Timeout | None = None,
    auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
    """Build an ``httpx.AsyncClient`` with TLS certificate verification off.

    Meant for servers using self-signed certificates; do not point it at
    endpoints whose identity matters. Redirects are followed, and the
    timeout falls back to 30 seconds when ``timeout`` is falsy. ``headers``
    and ``auth`` are forwarded only when truthy.
    """
    client_options = dict(
        follow_redirects=True,
        verify=False,
        timeout=timeout if timeout else httpx.Timeout(30.0),
    )
    passthrough = {"headers": headers, "auth": auth}
    client_options.update({name: value for name, value in passthrough.items() if value})
    return httpx.AsyncClient(**client_options)


class AuthenticatedA2AClient:
    """
    A2A client with integrated OAuth authentication.

    * :meth:`authenticate` performs the following steps:
      - discovers the OAuth endpoints protecting the MCP server at ``server_url``;
      - registers a client via Dynamic Client Registration;
      - runs a browser-based authorization-code + PKCE login against a local
        HTTPS callback server;
      - stores the resulting token.
    * :meth:`ask` sends a query to the Avi A2A agent with that token as a
      Bearer credential and logs the streamed reply.
    """

    # Scope requested both at client registration and in the authorization URL.
    OAUTH_SCOPE = "openid profile email offline_access avalara_mcp"

    # Attempts (3 s apart) for client registration and for the code exchange.
    _MAX_ATTEMPTS = 10

    def __init__(self, server_url: str, verbose: bool = False):
        """
        Initialize the authenticated A2A client.

        Args:
            server_url: The MCP server URL used for OAuth discovery. It is
                expected to answer 401 with a ``WWW-Authenticate`` header.
            verbose: Enable verbose (DEBUG) logging of OAuth details
        """
        self.server_url = server_url
        self.verbose = verbose
        self.access_token: Optional[str] = None
        self.token_response: Optional[Dict[str, Any]] = None

        self.logger = logging.getLogger(f"{__name__}.AuthenticatedA2AClient")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _parse_resource_metadata_url(www_auth: str) -> Optional[str]:
        """Return the ``resource_metadata`` parameter of a WWW-Authenticate value.

        Accepts both the quoted (``resource_metadata="https://..."``) and the
        bare-token form. Parameter names are matched case-insensitively.
        Returns None when the parameter is absent.
        """
        import re

        match = re.search(
            r'resource_metadata\s*=\s*(?:"([^"]*)"|([^\s,]+))', www_auth, re.IGNORECASE
        )
        if match is None:
            return None
        return match.group(1) if match.group(1) is not None else match.group(2)

    @staticmethod
    def _openid_discovery_url(auth_server: str) -> str:
        """Return the OpenID Connect discovery URL for an issuer, with or without a trailing slash."""
        return f"{auth_server.rstrip('/')}/.well-known/openid-configuration"

    @staticmethod
    def _extract_message(result: Dict[str, Any]) -> tuple[str, str]:
        """Return ``(role, text)`` for one streamed A2A ``result`` dict.

        The message is taken from ``status.message`` or, for results whose
        ``kind`` is ``"message"``, from the result itself. Parts without a
        ``text`` key (non-text parts) are skipped instead of raising KeyError.
        """
        if result.get("kind") == "message":
            message = result
        else:
            message = result.get("status", {}).get("message", {})
        role = message.get("role", "unknown")
        text = "\n".join(part["text"] for part in message.get("parts", []) if "text" in part)
        return role, text

    def _new_oauth_client(self, client_id: str, client_secret: str, redirect_uri: str) -> "AsyncOAuth2Client":
        """Create an authlib OAuth2 client; use it as ``async with`` so it is closed.

        NOTE(review): to log raw OAuth traffic, a transport can be installed
        with ``transport=`` (an httpx.AsyncClient option that AsyncOAuth2Client
        forwards). The transport must read the response body before touching
        ``response.content``.
        """
        return AsyncOAuth2Client(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=redirect_uri,
        )

    async def _discover_endpoints(self, limits: "httpx.Limits") -> Optional[Dict[str, str]]:
        """Discover the OAuth endpoints protecting ``self.server_url``.

        Discovery chain:
        1. GET the MCP server and expect a 401 with ``WWW-Authenticate``.
        2. Read the ``resource_metadata`` parameter and fetch that
           protected-resource metadata, falling back to the MCP server's own
           ``/.well-known/oauth-protected-resource``.
        3. Take ``authorization_servers[0]`` and fetch its OpenID discovery
           document.

        Returns:
            Dict with ``registration``, ``authorization`` and ``token``
            endpoint URLs, or None when a step fails (the reason is logged).
        """
        # NOTE(review): TLS verification is off for discovery, so these
        # responses (which decide where the code and secrets are sent) are not
        # authenticated. Confirm this is only needed for self-signed test servers.
        async with httpx.AsyncClient(
            verify=False,
            timeout=30.0,
            limits=limits,
            http2=False
        ) as http_client:
            self.logger.info("Discovering OAuth endpoints from MCP server...")
            mcp_response = await http_client.get(self.server_url)

            if mcp_response.status_code != 401:
                self.logger.error(f"Expected 401 from MCP server, got {mcp_response.status_code}")
                return None

            www_auth = mcp_response.headers.get("WWW-Authenticate")
            if not www_auth:
                self.logger.error("No WWW-Authenticate header found")
                return None

            if self.verbose:
                self.logger.debug(f"WWW-Authenticate: {www_auth}")

            resource_metadata_url = self._parse_resource_metadata_url(www_auth)
            if not resource_metadata_url:
                self.logger.error("Could not parse resource_metadata from WWW-Authenticate header")
                return None

            self.logger.info(f"Found resource metadata URL: {resource_metadata_url}")

            metadata_response = await http_client.get(resource_metadata_url)
            if metadata_response.status_code != 200:
                self.logger.warning(f"Resource metadata failed: {metadata_response.status_code}")
                # Try MCP server's own metadata endpoint
                mcp_metadata_url = f"{self.server_url}/.well-known/oauth-protected-resource"
                self.logger.info(f"Trying MCP server metadata: {mcp_metadata_url}")
                metadata_response = await http_client.get(mcp_metadata_url)

                if metadata_response.status_code != 200:
                    self.logger.error(f"MCP metadata also failed: {metadata_response.status_code}")
                    return None

            auth_servers = metadata_response.json().get("authorization_servers", [])
            if not auth_servers:
                self.logger.error("No authorization servers found in resource metadata")
                return None

            auth_server = auth_servers[0]
            self.logger.info(f"Found authorization server: {auth_server}")

            discovery_url = self._openid_discovery_url(auth_server)
            self.logger.info(f"Discovering endpoints from: {discovery_url}")

            discovery_response = await http_client.get(discovery_url)
            if discovery_response.status_code != 200:
                self.logger.error(f"Discovery failed: {discovery_response.status_code}")
                return None

            discovery_data = discovery_response.json()

        endpoints = {
            "registration": discovery_data.get("registration_endpoint"),
            "authorization": discovery_data.get("authorization_endpoint"),
            "token": discovery_data.get("token_endpoint"),
        }
        if not endpoints["registration"]:
            self.logger.error("No registration endpoint found - DCR may not be supported")
            return None
        if not endpoints["authorization"] or not endpoints["token"]:
            self.logger.error("Discovery document lacks authorization_endpoint or token_endpoint")
            return None

        self.logger.info(f"Found registration endpoint: {endpoints['registration']}")
        return endpoints

    async def _register_client(
        self, registration_url: str, redirect_uri: str, limits: "httpx.Limits"
    ) -> Optional[tuple[str, str]]:
        """Register an OAuth client via Dynamic Client Registration.

        Retries on any non-201 response. A secret containing ``+`` or ``%``
        also triggers re-registration; presumably this avoids URL-encoding
        trouble in the token request (TODO confirm). On the last attempt such
        a secret is accepted with a warning.

        Returns:
            ``(client_id, client_secret)``, or None if every attempt failed.
        """
        registration_data = {
            "client_name": "A2A Client",
            "redirect_uris": [redirect_uri],
            "grant_types": ["authorization_code", "refresh_token"],
            "response_types": ["code"],
            "token_endpoint_auth_method": "client_secret_basic",
            "scope": self.OAUTH_SCOPE,
        }

        for attempt in range(self._MAX_ATTEMPTS):
            is_last = attempt == self._MAX_ATTEMPTS - 1
            async with httpx.AsyncClient(
                verify=False,
                timeout=30.0,
                limits=limits,
                http2=False
            ) as http_client:
                self.logger.info(f"Registering OAuth client (attempt {attempt + 1}/{self._MAX_ATTEMPTS})...")
                response = await http_client.post(registration_url, json=registration_data)

            if response.status_code != 201:
                if is_last:
                    self.logger.error(f"DCR failed after all attempts: {response.status_code}")
                    return None
                self.logger.warning(f"DCR failed: {response.status_code} {response.text}")
                await asyncio.sleep(3)
                continue

            client_info = response.json()
            client_id = client_info["client_id"]
            client_secret = client_info["client_secret"]
            self.logger.info(f"Client registered: {client_id}")

            if '+' in client_secret or '%' in client_secret:
                self.logger.warning("Client secret contains URL encoding characters, retrying...")
                if not is_last:
                    await asyncio.sleep(3)
                    continue
                self.logger.warning("Warning: Client secret still contains URL encoding")

            return client_id, client_secret
        return None  # not reached: the last attempt always returns

    async def _exchange_code(
        self,
        token_url: str,
        auth_code: str,
        code_verifier: str,
        client_id: str,
        client_secret: str,
        redirect_uri: str,
    ) -> Dict[str, Any]:
        """Exchange the authorization code for tokens.

        Each attempt uses a fresh OAuth client, closed afterwards, and
        attempts are 3 s apart. The error of the final attempt is re-raised.
        """
        from urllib.parse import urlencode

        # URL-encode the code so characters such as '+' or '&' survive the
        # query-string parsing of `authorization_response`.
        authorization_response = f"{redirect_uri}?{urlencode({'code': auth_code})}"
        for attempt in range(self._MAX_ATTEMPTS):
            try:
                async with self._new_oauth_client(client_id, client_secret, redirect_uri) as oauth_client:
                    return await oauth_client.fetch_token(
                        token_url,
                        authorization_response=authorization_response,
                        code_verifier=code_verifier,
                    )
            except Exception as e:
                if attempt == self._MAX_ATTEMPTS - 1:
                    raise
                self.logger.warning(f"Token exchange failed (attempt {attempt + 1}/{self._MAX_ATTEMPTS}): {e}")
                await asyncio.sleep(3)
        raise RuntimeError("unreachable: the last attempt returns or raises")

    async def _run_callback_server(self, callback_server: "HTTPSCallbackServer") -> None:
        """Run the callback server with uvicorn/starlette logging muted.

        Every exception is swallowed and only logged at DEBUG, so nothing
        escapes the background task. This includes the cancellation issued
        during cleanup and any SystemExit uvicorn may raise when it cannot
        start (TODO confirm). A task that finishes early therefore signals a
        failed start.
        """
        uvicorn_logger = logging.getLogger("uvicorn")
        starlette_logger = logging.getLogger("starlette")
        original_uvicorn_level = uvicorn_logger.level
        original_starlette_level = starlette_logger.level
        try:
            uvicorn_logger.setLevel(logging.CRITICAL)
            starlette_logger.setLevel(logging.CRITICAL)
            await callback_server.start_server()
        except BaseException as exc:  # deliberate: see docstring
            self.logger.debug(f"Callback server stopped: {exc!r}")
        finally:
            uvicorn_logger.setLevel(original_uvicorn_level)
            starlette_logger.setLevel(original_starlette_level)

    @staticmethod
    async def _stop_callback_server(callback_server: "HTTPSCallbackServer", server_task: "asyncio.Task") -> None:
        """Request server exit, cancel its task, and wait at most 0.1 s (best effort, stderr silenced)."""
        await callback_server.shutdown()
        server_task.cancel()
        with suppress_stderr():
            try:
                await asyncio.wait_for(server_task, timeout=0.1)
            except (Exception, asyncio.CancelledError):
                # Timeouts and cancellation during cleanup are expected; ignore them.
                pass

    # --------------------------------------------------------------- public API

    async def authenticate(self) -> bool:
        """
        Perform the OAuth authentication flow.

        Steps:
        1. Discover the OAuth endpoints.
        2. Start the local HTTPS callback server.
        3. Register a client (DCR).
        4. Open the browser with a PKCE (S256) authorization URL.
        5. Wait for the redirect.
        6. Exchange the code for tokens.

        On success ``access_token`` and ``token_response`` are populated.

        Returns:
            True if authentication successful, False otherwise

        Raises:
            Exception: whatever ``HTTPSCallbackServer.wait_for_callback``
                raises (e.g. its timeout), or the last token-exchange error
                after all attempts failed.
        """
        self.logger.info("Starting OAuth authentication flow...")

        redirect_uri = f"https://127.0.0.1:{CALLBACK_PORT}/callback"
        # Single connection, no keep-alive, for the discovery and DCR requests.
        limits = httpx.Limits(max_connections=1, max_keepalive_connections=0)

        # Step 1: Call MCP server to get WWW-Authenticate header and discover endpoints
        endpoints = await self._discover_endpoints(limits)
        if endpoints is None:
            return False

        # Step 2: Start callback server
        callback_server = HTTPSCallbackServer(CALLBACK_PORT)
        server_task = asyncio.create_task(self._run_callback_server(callback_server))

        try:
            # Give uvicorn a moment to bind the port.
            await asyncio.sleep(1)
            if server_task.done():
                # Startup errors are swallowed by _run_callback_server; a task
                # that already finished means nothing is listening.
                self.logger.error(f"Callback server failed to start on port {CALLBACK_PORT}")
                return False

            # Step 3: Dynamic Client Registration
            registered = await self._register_client(endpoints["registration"], redirect_uri, limits)
            if registered is None:
                return False
            client_id, client_secret = registered

            # Step 4: Generate PKCE parameters (S256 challenge of a random verifier)
            code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')
            code_challenge = base64.urlsafe_b64encode(
                hashlib.sha256(code_verifier.encode('utf-8')).digest()
            ).decode('utf-8').rstrip('=')

            # Steps 5-6: Build the authorization URL
            async with self._new_oauth_client(client_id, client_secret, redirect_uri) as oauth_client:
                auth_url, _state = oauth_client.create_authorization_url(
                    endpoints["authorization"],
                    code_challenge=code_challenge,
                    scope=self.OAUTH_SCOPE,
                    code_challenge_method="S256",
                )
            # NOTE(review): `_state` is not compared with the state echoed on the
            # redirect, so CSRF protection rests on the loopback redirect alone.

            self.logger.info(f"Opening browser to: {auth_url}")
            webbrowser.open(auth_url)

            # Step 7: Wait for callback
            self.logger.info("Waiting for authorization callback...")
            auth_code = await callback_server.wait_for_callback()

            # Step 8: Exchange code for token
            token_response = await self._exchange_code(
                endpoints["token"], auth_code, code_verifier, client_id, client_secret, redirect_uri
            )

            self.access_token = token_response['access_token']
            self.token_response = token_response
            self.logger.info("OAuth authentication successful!")
            return True

        finally:
            await self._stop_callback_server(callback_server, server_task)

    async def ask(self, query: str) -> tuple[str, bool]:
        """
        Send ``query`` to the Avi A2A agent and log the streamed reply.

        Each streamed chunk is logged as ``role: text``.

        Args:
            query: User text to send.

        Returns:
            ``(text, success)``:
            - ``text`` is the last non-empty message text received, or ``""``
              if none.
            - ``success`` is False when the client is not authenticated or
              the stream reported a JSON-RPC error.
        """
        if not self.access_token:
            self.logger.error("Not authenticated - call authenticate() first")
            return "", False

        auth_headers_dict = {'Authorization': f'Bearer {self.access_token}'}
        send_message_payload: dict[str, Any] = {
            'message': {
                'role': 'user',
                'parts': [
                    {'kind': 'text', 'text': query},
                ],
                'messageId': uuid4().hex,
            },
        }

        last_text = ""
        success = True
        async with httpx.AsyncClient() as httpx_client:
            resolver = A2ACardResolver(httpx_client=httpx_client, base_url=AVI_A2A_URL)
            agent_card: AgentCard = await resolver.get_agent_card()
            client = A2AClient(httpx_client=httpx_client, agent_card=agent_card)
            streaming_request = SendStreamingMessageRequest(
                id=str(uuid4()), params=MessageSendParams(**send_message_payload)
            )
            stream_response = client.send_message_streaming(
                streaming_request, http_kwargs={'headers': auth_headers_dict}
            )

            async for chunk in stream_response:
                payload = chunk.model_dump(mode='json', exclude_none=True)
                if "error" in payload:
                    self.logger.error(f"A2A error: {payload['error']}")
                    success = False
                    continue
                who, what = self._extract_message(payload.get("result", {}))
                self.logger.info(f"\n{who}: {what}")
                if what:
                    last_text = what

        return last_text, success

    async def close(self):
        """No-op: every HTTP client this class uses is opened and closed per call."""
        pass

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test OAuth authentication flow for A2A server")
    parser.add_argument(
        "product",
        nargs="?",
        default=DEFAULT_PRODUCT,
        help=f"Product name to test (default: {DEFAULT_PRODUCT})"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose OAuth logging"
    )

    args = parser.parse_args()

    if args.verbose:
        # logging.basicConfig() configured the root logger at INFO, which drops
        # every DEBUG message the verbose mode produces. Lower only this
        # script's logger; its "<name>.<Class>" child loggers inherit the level,
        # so third-party libraries stay at INFO.
        logger.setLevel(logging.DEBUG)

    async def run_main():
        """Authenticate, send one query to the A2A agent, and return overall success."""
        server_url = get_server_url(args.product)
        logger.info(f"Testing auth flow for product '{args.product}' at: {server_url}")

        async with AuthenticatedA2AClient(server_url, verbose=args.verbose) as client:
            # Step 1: Authenticate
            logger.info("Step 1: Starting OAuth authentication...")
            if not await client.authenticate():
                logger.error("Authentication failed!")
                return False

            logger.info("✓ OAuth authentication successful!")

            # Step 2: Connect to A2A server
            logger.info("Step 2: Connecting to A2A server...")
            response, success = await client.ask("what agents do you have?")
            if not success:
                logger.error("Connection failed!")
                return False
            logger.info(f"A2A server response: {response}")

            logger.info("✓ A2A server connection successful!")
            logger.info("Auth flow test completed successfully!")
            return True

    success = asyncio.run(run_main())
    if success:
        logger.info("🎉 Auth flow test PASSED")
    else:
        logger.error("❌ Auth flow test FAILED")
    # Report the outcome as the exit status so shells and CI can detect failure.
    sys.exit(0 if success else 1)


