#!/usr/bin/env python3
"""
Simple zone list server for the DNS Editor app.
Returns a JSON list of the primary zones configured in BIND9.

Usage:
    ./zone_server.py [--port PORT] [--bind ADDRESS] [--cert-dir DIR]

Examples:
    # HTTP (development only)
    ./zone_server.py --port 8053 --bind 0.0.0.0

    # HTTPS with Let's Encrypt certificates
    ./zone_server.py --port 443 --cert-dir /etc/letsencrypt/live/nameserver.example.com
"""

import subprocess
import json
import argparse
import socket
import ssl
import os
import sys
from http.server import HTTPServer, BaseHTTPRequestHandler


def _is_reverse_zone(zone_name):
    """Return True when *zone_name* is, or sits below, in-addr.arpa or ip6.arpa."""
    name = zone_name.lower().rstrip('.')
    # Compare whole labels so names that merely contain "arpa"
    # (e.g. "carpark.example.org") are not mistaken for reverse zones.
    return any(
        name == suffix or name.endswith('.' + suffix)
        for suffix in ('in-addr.arpa', 'ip6.arpa')
    )


def _parse_zone_list(output):
    """Turn ``named-checkconf -l`` output into sorted forward/reverse zone lists."""
    forward = set()
    reverse = set()

    for line in output.splitlines():
        parts = line.split()
        # The first field is used as the zone name and the last as the zone
        # type; both the legacy "master" and the newer "primary" spelling
        # are accepted.
        if len(parts) < 2 or parts[-1].lower() not in ('master', 'primary'):
            continue

        zone_name = parts[0]
        # Skip any internal zones you want to hide
        if '.ivar' in zone_name:
            continue

        # Sets collapse a zone that is listed once per view into one entry.
        if _is_reverse_zone(zone_name):
            reverse.add(zone_name)
        else:
            forward.add(zone_name)

    return {'forward': sorted(forward), 'reverse': sorted(reverse)}


def get_primary_zones():
    """Get the primary (master) zones from BIND9 via ``named-checkconf -l``.

    Returns:
        A dict with two sorted, duplicate-free lists of zone names:
        ``'forward'`` and ``'reverse'`` (zones under in-addr.arpa or
        ip6.arpa). When the command times out, is not installed, exits with
        a non-zero status, or fails in any other way, both lists are empty
        and an ``'error'`` key holds a short description. This function
        does not raise.
    """
    try:
        result = subprocess.run(
            ['named-checkconf', '-l'],
            capture_output=True,
            text=True,
            timeout=10
        )
    except subprocess.TimeoutExpired:
        return {'error': 'timeout', 'forward': [], 'reverse': []}
    except FileNotFoundError:
        return {'error': 'named-checkconf not found', 'forward': [], 'reverse': []}
    except Exception as e:
        # Boundary for the HTTP handler: report the failure in the JSON
        # payload instead of letting it propagate.
        return {'error': str(e), 'forward': [], 'reverse': []}

    if result.returncode != 0:
        return {'error': 'named-checkconf failed', 'forward': [], 'reverse': []}

    return _parse_zone_list(result.stdout)


class ZoneRequestHandler(BaseHTTPRequestHandler):
    """HTTP request handler that serves the zone list as JSON.

    GET on any path in ``ZONE_LIST_PATHS`` returns the result of
    ``get_primary_zones()``; every other GET path returns a JSON 404.
    OPTIONS answers CORS preflight requests. All responses allow any origin.
    """

    # Request paths (query string removed) that return the zone list.
    ZONE_LIST_PATHS = ('/', '/zones', '/zones/')

    def log_message(self, format, *args):
        """Log requests to stdout.

        The parameter is named ``format`` to match the base-class method
        this overrides.
        """
        print(f"{self.address_string()} - {format % args}")

    def send_json(self, data, status=200):
        """Send *data* as an indented JSON response with the given status.

        Args:
            data: JSON-serializable object to send as the response body.
            status: HTTP status code for the response (default 200).
        """
        body = json.dumps(data, indent=2).encode('utf-8')
        self.send_response(status)
        self.send_header('Content-Type', 'application/json')
        # Header values are text; convert the byte count explicitly.
        self.send_header('Content-Length', str(len(body)))
        self.send_header('Access-Control-Allow-Origin', '*')
        self.end_headers()
        self.wfile.write(body)

    def _route_path(self):
        """Return the request path without its query string or fragment."""
        return self.path.split('?', 1)[0].split('#', 1)[0]

    def do_GET(self):
        """Handle GET requests."""
        # Route on the bare path so "/zones?refresh=1" still matches.
        if self._route_path() in self.ZONE_LIST_PATHS:
            self.send_json(get_primary_zones())
        else:
            self.send_json({'error': 'Not found'}, 404)

    def do_OPTIONS(self):
        """Handle CORS preflight."""
        self.send_response(200)
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'GET, OPTIONS')
        self.send_header('Access-Control-Allow-Headers', 'Content-Type')
        # The preflight reply has no body; say so explicitly.
        self.send_header('Content-Length', '0')
        self.end_headers()


def create_ssl_context(cert_dir):
    """Create a server-side SSL context from Let's Encrypt certificate files.

    Args:
        cert_dir: Directory containing ``fullchain.pem`` and ``privkey.pem``.

    Returns:
        An ``ssl.SSLContext`` for server use with the certificate chain
        loaded and TLS 1.2 as the minimum protocol version.

    Exits the process with status 1, after printing a message to stderr,
    when either file is missing or the pair cannot be loaded.
    """
    cert_file = os.path.join(cert_dir, 'fullchain.pem')
    key_file = os.path.join(cert_dir, 'privkey.pem')

    # Check that certificate files exist
    if not os.path.exists(cert_file):
        print(f"Error: Certificate file not found: {cert_file}", file=sys.stderr)
        sys.exit(1)
    if not os.path.exists(key_file):
        print(f"Error: Private key file not found: {key_file}", file=sys.stderr)
        sys.exit(1)

    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)

    # Set secure defaults
    context.minimum_version = ssl.TLSVersion.TLSv1_2

    try:
        context.load_cert_chain(cert_file, key_file)
    except OSError as exc:
        # ssl.SSLError is a subclass of OSError, so this covers malformed or
        # mismatched PEM data as well as unreadable files. Report it the same
        # way as a missing file instead of dumping a traceback.
        print(f"Error: Could not load certificate/key from {cert_dir}: {exc}",
              file=sys.stderr)
        sys.exit(1)

    return context


def main():
    """Parse command-line options and run the zone list server until Ctrl+C.

    Serves HTTPS when ``--cert-dir`` is given, plain HTTP otherwise. When
    ``--port`` is omitted the port is 443 for HTTPS and 8053 for HTTP; an
    explicitly given port is always used as-is. Exits with status 1 if the
    certificates cannot be loaded or the address cannot be bound.
    """
    parser = argparse.ArgumentParser(description='Zone list server for DNS Editor')
    # default=None lets an omitted --port be told apart from an explicit
    # "--port 8053", which must not be replaced by 443 under HTTPS.
    parser.add_argument('--port', '-p', type=int, default=None,
                        help='Port to listen on (default: 8053, or 443 for HTTPS)')
    parser.add_argument('--bind', '-b', default='0.0.0.0',
                        help='Address to bind to (default: 0.0.0.0)')
    parser.add_argument('--cert-dir', '-c', default=None,
                        help='Directory containing fullchain.pem and privkey.pem (enables HTTPS)')
    args = parser.parse_args()

    # Determine if using HTTPS
    use_https = args.cert_dir is not None

    # Default port based on protocol
    port = args.port
    if port is None:
        if use_https:
            port = 443
            print("Note: Using default HTTPS port 443 (override with --port)")
        else:
            port = 8053

    # Load the certificates before binding so a bad --cert-dir fails without
    # ever opening the listening socket.
    ssl_context = create_ssl_context(args.cert_dir) if use_https else None

    try:
        httpd = HTTPServer((args.bind, port), ZoneRequestHandler)
    except OSError as exc:
        print(f"Error: Cannot listen on {args.bind}:{port}: {exc}", file=sys.stderr)
        sys.exit(1)

    # The context manager calls server_close() on exit, releasing the socket.
    with httpd:
        # Configure SSL if certificate directory provided
        if ssl_context is not None:
            httpd.socket = ssl_context.wrap_socket(httpd.socket, server_side=True)
            protocol = "https"
        else:
            protocol = "http"
            print("Warning: Running without HTTPS. Use --cert-dir for production.", file=sys.stderr)

        print(f"Zone server listening on {protocol}://{args.bind}:{port}/")
        print("Endpoints:")
        print("  GET /       - List all primary zones")
        print("  GET /zones  - List all primary zones")
        if use_https:
            print(f"Certificate directory: {args.cert_dir}")
        print()
        print("Press Ctrl+C to stop")

        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            # serve_forever() has already returned here, so there is no
            # running loop to shut down; leaving the with-block closes the
            # socket.
            print("\nShutting down...")


# Start the server only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
