"""End-to-end tests for JWT algorithm validation.

Tests algorithm validation to prevent algorithm confusion attacks and ensure
only secure asymmetric algorithms are accepted.
"""

import base64
import hashlib
import hmac
import json
import time

import pytest
from flask import Flask, jsonify
from jwcrypto import jwk
from jwcrypto import jwt as jwcrypto_jwt

from axioms_flask.decorators import has_valid_access_token, has_required_scopes
from axioms_flask.error import AxiomsError


# Generate RSA key pair for testing
def generate_test_keys():
    """Return a newly generated 2048-bit RSA JWK whose kid is ``test-key-id``.

    The private key signs test tokens; the ``mock_jwks_data`` fixture
    publishes its public half.
    """
    return jwk.JWK.generate(kid='test-key-id', kty='RSA', size=2048)


# Generate JWT token
def generate_jwt_token(key, claims, alg='RS256'):
    """Sign ``claims`` with ``key`` and return the compact-serialized JWT.

    Args:
        key: jwcrypto JWK used for signing; its ``kid`` is copied into the
            protected header.
        claims: Claim set handed unchanged to ``jwcrypto.jwt.JWT`` (callers
            in this file pass a JSON string).
        alg: JWS algorithm written to the header's ``alg`` field.

    Returns:
        The serialized ``header.payload.signature`` token string.
    """
    protected_header = {"alg": alg, "kid": key.kid}
    signed = jwcrypto_jwt.JWT(header=protected_header, claims=claims)
    signed.make_signed_token(key)
    return signed.serialize()


# Create malformed token with unsupported algorithm
def create_token_with_none_alg(claims):
    """Build an unsigned ``alg: none`` token carrying ``claims``.

    A verifier must never accept such a token. The signature segment is
    empty, so the compact serialization ends with a bare trailing dot.
    """
    def _segment(obj):
        # JSON-encode, then base64url-encode with the '=' padding removed.
        return base64.urlsafe_b64encode(json.dumps(obj).encode()).decode().rstrip('=')

    unsigned_header = {"alg": "none", "typ": "JWT", "kid": "test-key-id"}
    return '.'.join((_segment(unsigned_header), _segment(claims), ''))


# Create test Flask application
@pytest.fixture
def app():
    """Build a Flask app that exposes one protected route for the tests.

    ``/private`` is wrapped in ``has_valid_access_token`` and requires the
    ``openid`` and ``profile`` scopes. Any ``AxiomsError`` is rendered as a
    JSON body of ``error.error`` with ``error.status_code``.
    """
    flask_app = Flask(__name__)
    flask_app.config.update(
        TESTING=True,
        AXIOMS_AUDIENCE='test-audience',
        AXIOMS_JWKS_URL='https://test-domain.com/.well-known/jwks.json',
    )

    @flask_app.errorhandler(AxiomsError)
    def handle_axioms_error(error):
        # Surface the library's error payload together with its HTTP status.
        return jsonify(error.error), error.status_code

    # Decorator order matters: token validation wraps the scope check.
    @flask_app.route('/private', methods=['GET'])
    @has_valid_access_token
    @has_required_scopes(['openid', 'profile'])
    def api_private():
        return jsonify({'message': 'Private endpoint'})

    return flask_app


@pytest.fixture
def client(app):
    """Return a Flask test client bound to the ``app`` fixture."""
    test_client = app.test_client()
    return test_client


@pytest.fixture(scope='module')
def test_key():
    """Provide the RSA signing key shared by every test in this module.

    Generating an RSA-2048 key is comparatively slow. The autouse JWKS mock
    depends on this fixture, so a function-scoped key would be regenerated
    for every single test. The tests only read the key (sign, export), so
    one key per module is enough. Function-scoped fixtures such as
    ``mock_jwks_data`` can still depend on it.
    """
    return generate_test_keys()


@pytest.fixture
def mock_jwks_data(test_key):
    """Return a JWKS document, as UTF-8 JSON bytes, listing ``test_key``'s public JWK."""
    return json.dumps(
        {'keys': [test_key.export_public(as_dict=True)]}
    ).encode('utf-8')


@pytest.fixture(autouse=True)
def mock_jwks_fetch(monkeypatch, mock_jwks_data):
    """Swap ``axioms_flask.token.CacheFetcher`` for an offline stub.

    Because this fixture is autouse, it applies to every test. The stub's
    ``fetch`` ignores the URL and always returns ``mock_jwks_data``.
    """
    from axioms_flask import token as token_module

    class MockCacheFetcher:
        def fetch(self, url, max_age=300):
            # url and max_age are deliberately ignored; the canned JWKS is returned.
            return mock_jwks_data

    monkeypatch.setattr(token_module, 'CacheFetcher', MockCacheFetcher)


# Test classes
class TestAlgorithmValidation:
    """Test JWT algorithm validation for security.

    Each test sends a bearer token to the ``/private`` route of the ``app``
    fixture. Well-formed tokens are signed with ``test_key``, whose public
    half the mocked JWKS fetch serves. Hostile tokens are assembled by hand,
    so only the property under test is wrong.
    """

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _claims():
        """Return a claim set that the ``/private`` route would otherwise accept.

        ``aud`` matches ``AXIOMS_AUDIENCE`` and ``scope`` covers the route's
        required scopes. A rejection can therefore only come from the header
        or signature being tested.
        """
        now = int(time.time())
        return {
            'sub': 'user123',
            'aud': ['test-audience'],
            'scope': 'openid profile',
            'exp': now + 3600,
            'iat': now,
        }

    @staticmethod
    def _b64url(raw):
        """Base64url-encode ``raw`` bytes with the ``=`` padding stripped."""
        return base64.urlsafe_b64encode(raw).decode().rstrip('=')

    @classmethod
    def _signing_input(cls, header, claims):
        """Return the ``<header>.<payload>`` part of a hand-built compact JWS."""
        header_b64 = cls._b64url(json.dumps(header).encode())
        payload_b64 = cls._b64url(json.dumps(claims).encode())
        return f"{header_b64}.{payload_b64}"

    @staticmethod
    def _get_private(client, token):
        """GET ``/private`` with ``token`` presented as a bearer credential."""
        return client.get('/private', headers={'Authorization': f'Bearer {token}'})

    @staticmethod
    def _assert_unauthorized(response, *description_fragments):
        """Assert a 401 ``unauthorized_access`` JSON error response.

        When ``description_fragments`` are given, at least one of them must
        occur in the lower-cased ``error_description``.
        """
        assert response.status_code == 401
        data = json.loads(response.data)
        assert data['error'] == 'unauthorized_access'
        if description_fragments:
            description = data['error_description'].lower()
            assert any(
                fragment in description for fragment in description_fragments
            ), description

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_valid_rs256_algorithm(self, client, test_key):
        """Test that RS256 algorithm is accepted."""
        token = generate_jwt_token(test_key, json.dumps(self._claims()), alg='RS256')
        response = self._get_private(client, token)
        assert response.status_code == 200

    def test_reject_none_algorithm(self, client):
        """Test that 'none' algorithm is rejected (critical security test)."""
        token = create_token_with_none_alg(self._claims())
        self._assert_unauthorized(self._get_private(client, token), 'algorithm')

    def test_reject_hs256_symmetric_algorithm(self, client, test_key):
        """Test that symmetric algorithms like HS256 are rejected.

        This mounts a real algorithm-confusion attack instead of using a
        dummy signature. The token claims HS256 and is HMAC-SHA256-signed
        with the server's RSA *public* key (which anyone can fetch from the
        JWKS) as the shared secret. A verifier that let the token header
        pick the algorithm could accept this signature as valid.
        """
        header = {"alg": "HS256", "typ": "JWT", "kid": "test-key-id"}
        signing_input = self._signing_input(header, self._claims())
        public_pem = test_key.export_to_pem()  # public key only by default
        signature = hmac.new(public_pem, signing_input.encode(), hashlib.sha256).digest()
        token = f"{signing_input}.{self._b64url(signature)}"

        self._assert_unauthorized(self._get_private(client, token), 'algorithm')

    def test_reject_missing_algorithm(self, client):
        """Test that tokens without algorithm are rejected."""
        header = {"typ": "JWT", "kid": "test-key-id"}  # no 'alg'
        signing_input = self._signing_input(header, self._claims())
        token = f"{signing_input}.{self._b64url(b'fake_signature')}"

        self._assert_unauthorized(self._get_private(client, token))

    def test_reject_missing_kid(self, client, test_key):
        """Test that tokens without key ID are rejected."""
        # A genuinely RS256-signed token whose header simply omits 'kid'.
        token_obj = jwcrypto_jwt.JWT(
            header={"alg": "RS256"},
            claims=json.dumps(self._claims()),
        )
        token_obj.make_signed_token(test_key)

        response = self._get_private(client, token_obj.serialize())
        self._assert_unauthorized(response, 'key id', 'kid')

    def test_allowed_algorithms_coverage(self, client, test_key):
        """Test that the RSA PKCS#1 v1.5 signature algorithms are accepted.

        RS256, RS384 and RS512 differ only in their hash function, so the
        single RSA test key can sign all three. Covering EC or OKP
        algorithms (ES256, EdDSA, ...) would need keys of that type
        published in the mocked JWKS.
        """
        for alg in ('RS256', 'RS384', 'RS512'):
            token = generate_jwt_token(test_key, json.dumps(self._claims()), alg=alg)
            response = self._get_private(client, token)
            assert response.status_code == 200, f'{alg} token was rejected'

    def test_invalid_jwt_format(self, client):
        """Test that malformed JWT tokens are rejected."""
        # Only two dot-separated segments instead of the required three.
        self._assert_unauthorized(self._get_private(client, "invalid.token"))
