import os
import io
import re
import shutil
import logging
from typing import Optional, Tuple

import PyPDF2
from reportlab.pdfgen import canvas
from reportlab.lib.pagesizes import letter

import fitz  # PyMuPDF
from PIL import Image

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def merge_pdfs_with_index(directory_path, output_filename="merged_document.pdf", add_page_numbers=True):
    """
    Merge all PDF files in a directory with an index table at the beginning.
    
    Args:
        directory_path (str): Path to directory containing PDF files
        output_filename (str): Name of the output merged PDF file
        add_page_numbers (bool): Whether to add visible page numbers to each page
    """
    
    # Collect all PDF files from the directory (sorted for consistent ordering)
    pdf_files = [f for f in sorted(os.listdir(directory_path)) if f.lower().endswith('.pdf')]
    if not pdf_files:
        logger.warning("No PDF files found in the directory.")
        return
    
    # Track document info for index
    document_info = []
    current_page = 2  # Start from page 2 (index will be page 1)
    
    # First pass: collect info for index
    for pdf_file in pdf_files:
        pdf_path = os.path.join(directory_path, pdf_file)
        try:
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                num_pages = len(pdf_reader.pages)
                
                document_info.append({
                    'name': os.path.splitext(pdf_file)[0],  # Remove .pdf extension
                    'start_page': current_page,
                    'end_page': current_page + num_pages - 1,
                    'file_path': pdf_path
                })
                
                current_page += num_pages
        except Exception as e:
            logger.warning(f"Error reading {pdf_file}: {e}")
            continue
    
    if not document_info:
        logger.error("No valid PDF files could be processed.")
        return
    
    # Create index page
    index_pdf = create_index_page(document_info)
    
    # Create final merged PDF
    pdf_writer = PyPDF2.PdfWriter()
    
    # Add index page first
    index_reader = PyPDF2.PdfReader(index_pdf)
    pdf_writer.add_page(index_reader.pages[0])
    
    # Add all other PDFs with optional page numbering
    page_number = 2  # Start numbering from page 2
    
    for doc_info in document_info:
        try:
            with open(doc_info['file_path'], 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                
                for page in pdf_reader.pages:
                    if add_page_numbers:
                        # Add page number to the page
                        numbered_page = add_page_number_to_page(page, page_number)
                        pdf_writer.add_page(numbered_page)
                    else:
                        pdf_writer.add_page(page)
                    page_number += 1
        except Exception as e:
            logger.error(f"Error processing {doc_info['name']}: {e}")
            continue
    
    # Save the merged PDF
    output_path = os.path.join(directory_path, output_filename)
    try:
        with open(output_path, 'wb') as output_file:
            pdf_writer.write(output_file)
        
        logger.info(f"✅ Merged PDF created successfully: {output_path}")
        logger.info(f"📄 Total pages: {len(pdf_writer.pages)}")
        logger.info(f"📑 Documents merged: {len(document_info)}")
        
        # Print summary
        logger.info("\n📋 Document Summary:")
        for doc in document_info:
            if doc['start_page'] == doc['end_page']:
                pages_text = f"Page {doc['start_page']}"
            else:
                pages_text = f"Pages {doc['start_page']}-{doc['end_page']}"
            logger.info(f"  • {doc['name']}: {pages_text}")
            
    except Exception as e:
        logger.error(f"Error saving merged PDF: {e}")

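# Example usage of merge_pdfs_with_index (an illustrative sketch; "/path/to/pdfs"
# is a placeholder directory, not a path assumed to exist):
#
#   merge_pdfs_with_index("/path/to/pdfs",
#                         output_filename="combined.pdf",
#                         add_page_numbers=True)
#
# This writes "combined.pdf" into the same directory, with the generated table
# of contents as page 1 followed by every source PDF in sorted filename order.
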
def create_index_page(document_info):
    """Create a table of contents page."""
    buffer = io.BytesIO()
    
    # Create canvas
    c = canvas.Canvas(buffer, pagesize=letter)
    width, height = letter
    
    # Title
    c.setFont("Helvetica-Bold", 18)
    title_width = c.stringWidth("Table of Contents", "Helvetica-Bold", 18)
    c.drawString((width - title_width) / 2, height - 60, "Table of Contents")
    
    # Add a line under title
    c.setStrokeColorRGB(0.3, 0.3, 0.3)
    c.setLineWidth(1)
    c.line(50, height - 80, width - 50, height - 80)
    
    # Headers
    c.setFont("Helvetica-Bold", 12)
    c.setFillColorRGB(0, 0, 0)
    c.drawString(60, height - 110, "Document Name")
    c.drawString(width - 150, height - 110, "Pages")
    
    # Draw line under headers
    c.setLineWidth(0.5)
    c.line(50, height - 120, width - 50, height - 120)
    
    # Document entries
    c.setFont("Helvetica", 11)
    y_position = height - 145
    
    for i, doc in enumerate(document_info):
        if y_position < 80:  # If we're running out of space
            c.drawString(60, y_position, "... (additional documents)")
            break
        
        # Alternate row background
        if i % 2 == 0:
            c.setFillColorRGB(0.95, 0.95, 0.95)
            c.rect(50, y_position - 3, width - 100, 18, fill=1, stroke=0)
        
        c.setFillColorRGB(0, 0, 0)
        
        # Document name (truncate if too long)
        doc_name = re.sub(r'^\d+_', '', doc['name'])  # Strip any leading numeric prefix (e.g. '01_') used for file ordering
        if len(doc_name) > 40:
            doc_name = doc_name[:37] + "..."
        c.drawString(60, y_position, doc_name)
        
        # Page range
        if doc['start_page'] == doc['end_page']:
            page_text = str(doc['start_page'])
        else:
            page_text = f"{doc['start_page']}-{doc['end_page']}"
        
        # Right-align page numbers
        page_width = c.stringWidth(page_text, "Helvetica", 11)
        c.drawString(width - 100 - page_width, y_position, page_text)
        
        y_position -= 22
    
    # Add footer with generation info
    c.setFont("Helvetica", 8)
    c.setFillColorRGB(0.5, 0.5, 0.5)
    c.drawString(60, 40, f"Generated index - Total documents: {len(document_info)}")
    
    # Add page number to index page
    c.drawString(width - 80, 40, "Page 1")
    
    c.save()
    buffer.seek(0)
    return buffer

def add_page_number_to_page(page, page_number):
    """Add a page number to a PDF page by merging a text-only overlay onto it."""
    # Get the dimensions of the original page
    page_width = float(page.mediabox.width)
    page_height = float(page.mediabox.height)
    
    # Create an overlay PDF containing just the page number, sized to match the page
    packet = io.BytesIO()
    can = canvas.Canvas(packet, pagesize=(page_width, page_height))
    
    # Add page number at bottom center
    can.setFont("Helvetica", 10)
    page_text = f"- {page_number} -"
    text_width = can.stringWidth(page_text, "Helvetica", 10)
    
    # Position at bottom center
    x_position = (page_width - text_width) / 2
    y_position = 30
    
    can.setFillColorRGB(0.3, 0.3, 0.3)
    can.drawString(x_position, y_position, page_text)
    can.save()
    
    # Move to the beginning of the BytesIO buffer
    packet.seek(0)
    
    # Create a new PDF reader for the page number
    page_number_pdf = PyPDF2.PdfReader(packet)
    page_number_page = page_number_pdf.pages[0]
    
    # Merge the page number with the original page
    page.merge_page(page_number_page)
    
    return page

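# Illustrative helper (a sketch, not part of the original workflow): shows how
# add_page_number_to_page can stamp page numbers onto an existing standalone
# PDF. The default file names below are placeholders.
def number_existing_pdf(input_path="input.pdf", output_path="input_numbered.pdf"):
    """Write a copy of input_path with '- n -' footers stamped on every page."""
    writer = PyPDF2.PdfWriter()
    with open(input_path, 'rb') as source:
        reader = PyPDF2.PdfReader(source)
        for page_number, page in enumerate(reader.pages, start=1):
            writer.add_page(add_page_number_to_page(page, page_number))
        # Write while the source file is still open, since pages can reference
        # the underlying stream
        with open(output_path, 'wb') as target:
            writer.write(target)
    return output_path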

def compress_pdf(input_path: str, output_path: Optional[str] = None, target_size_percent: int = 50,
                 max_attempts: int = 10, preserve_quality: bool = True) -> Tuple[bool, str, dict]:
    """
    Compress a PDF file to achieve a target file size reduction.
    
    Args:
        input_path (str): Path to the input PDF file
        output_path (str, optional): Path for the compressed PDF (if None, a '_compressed_<percent>percent' suffix is added to the original name)
        target_size_percent (int): Target size as percentage of original (25, 50, or 75)
        max_attempts (int): Maximum compression attempts to reach target size
        preserve_quality (bool): Whether to prioritize quality over exact size target
    
    Returns:
        Tuple[bool, str, dict]: (success, output_path, compression_stats)
    """
    
    # Validate inputs
    if not os.path.exists(input_path):
        return False, "", {"error": "Input file does not exist"}
    
    if target_size_percent not in [25, 50, 75]:
        return False, "", {"error": "Target size must be 25%, 50%, or 75%"}
    
    # Get original file size
    original_size = os.path.getsize(input_path)
    target_size = original_size * (target_size_percent / 100)
    
    # Generate output path if not provided
    if output_path is None:
        base_name = os.path.splitext(input_path)[0]
        output_path = f"{base_name}_compressed_{target_size_percent}percent.pdf"
    
    logger.info("🗜️  Starting PDF compression...")
    logger.info(f"📄 Original size: {format_file_size(original_size)}")
    logger.info(f"🎯 Target size: {format_file_size(target_size)} ({target_size_percent}%)")
    
    try:
        # Open the PDF
        doc = fitz.open(input_path)
        
        # Compression stats
        stats = {
            "original_size": original_size,
            "target_size": target_size,
            "target_percent": target_size_percent,
            "attempts": 0,
            "final_size": 0,
            "compression_ratio": 0,
            "achieved_percent": 0,
            "method_used": "",
            "pages_processed": len(doc)
        }
        
        # Try different compression strategies
        success, final_path, method = attempt_compression_strategies(
            doc, output_path, target_size, target_size_percent, 
            max_attempts, preserve_quality, stats
        )
        
        if success:
            final_size = os.path.getsize(final_path)
            stats["final_size"] = final_size
            stats["compression_ratio"] = (original_size - final_size) / original_size * 100
            stats["achieved_percent"] = (final_size / original_size) * 100
            stats["method_used"] = method
            
            logger.info("✅ Compression successful!")
            logger.info(f"📊 Final size: {format_file_size(final_size)}")
            logger.info(f"📉 Compression: {stats['compression_ratio']:.1f}% reduction")
            logger.info(f"🎯 Achieved: {stats['achieved_percent']:.1f}% of original size")
            
            return True, final_path, stats
        else:
            return False, "", stats
            
    except Exception as e:
        return False, "", {"error": f"Compression failed: {str(e)}"}
    finally:
        if 'doc' in locals():
            doc.close()

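# Example usage of compress_pdf (an illustrative sketch; "report.pdf" is a
# placeholder file name):
#
#   ok, path, stats = compress_pdf("report.pdf", target_size_percent=50)
#   if ok:
#       print(f"Compressed copy at {path}: "
#             f"{stats['achieved_percent']:.1f}% of the original size")
#   else:
#       print(f"Compression failed: {stats.get('error', 'no strategy reached the target')}")
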
def attempt_compression_strategies(doc, output_path, target_size, target_percent,
                                   max_attempts, preserve_quality, stats):
    """Try different compression strategies to achieve target size."""
    
    current_doc = doc  # Start with original document
    
    strategies = [
        ("basic_compression", lambda d: basic_pdf_compression(d, output_path)),
        ("image_optimization", lambda d: compress_with_image_optimization(d, output_path, target_percent)),
        ("aggressive_compression", lambda d: aggressive_pdf_compression(d, output_path, target_percent)),
        ("hybrid_compression", lambda d: hybrid_compression_approach(d, output_path, target_percent))
    ]
    
    for strategy_name, strategy_func in strategies:
        try:
            logger.info(f"🔄 Trying {strategy_name.replace('_', ' ').title()}...")
            temp_output = strategy_func(current_doc)
            
            if temp_output and os.path.exists(temp_output):
                current_size = os.path.getsize(temp_output)
                stats["attempts"] += 1
                
                logger.info(f"   📏 Size achieved: {format_file_size(current_size)} ({(current_size/stats['original_size'])*100:.1f}%)")
                
                # Check if we've reached our target or if this is the best we can do
                if current_size <= target_size or not preserve_quality:
                    if temp_output != output_path:
                        shutil.move(temp_output, output_path)
                    if current_doc != doc:
                        current_doc.close()
                    return True, output_path, strategy_name
                
                # If we're close enough and preserving quality, accept it
                if preserve_quality and current_size <= target_size * 1.1:  # Within 10% of target
                    if temp_output != output_path:
                        shutil.move(temp_output, output_path)
                    if current_doc != doc:
                        current_doc.close()
                    return True, output_path, strategy_name

                # Update current_doc for next strategy (if continuing)
                if current_doc != doc:
                    current_doc.close()  # Close previous intermediate document
                current_doc = fitz.open(temp_output)  # Open output as new input

                # After the last strategy, keep its result as the best effort
                # even if the target size was not reached
                if strategy_name == "hybrid_compression" and os.path.exists(temp_output):
                    if current_doc != doc:
                        current_doc.close()
                    if temp_output != output_path:
                        shutil.move(temp_output, output_path)
                    return True, output_path, strategy_name
                    
        except Exception as e:
            logger.warning(f"   ❌ {strategy_name} failed: {str(e)}")
            continue
            
    # No strategy reached the target size; close any remaining intermediate document
    if current_doc != doc:
        current_doc.close()
    return False, "", "no_method_successful"

def basic_pdf_compression(doc, output_path):
    """Basic PDF compression using PyMuPDF's built-in methods."""
    # Create a new document with compression
    new_doc = fitz.open()
    
    for page_num in range(len(doc)):
        page = doc[page_num]
        # Create a new page and insert the content
        new_page = new_doc.new_page(width=page.rect.width, height=page.rect.height)
        new_page.show_pdf_page(page.rect, doc, page_num)
    
    # Save with compression options
    new_doc.save(output_path, 
                 garbage=4,           # Garbage collection level
                 clean=True,          # Clean up PDF
                 deflate=True,        # Compress streams
                 deflate_images=True, # Compress images
                 deflate_fonts=True)  # Compress fonts
    
    new_doc.close()
    return output_path

def compress_with_image_optimization(doc, output_path, target_percent):
    """Compress PDF by optimizing images within it."""
    new_doc = fitz.open()
    
    # Determine image quality based on target compression
    if target_percent <= 25:
        img_quality = 30
        max_dimension = 800
    elif target_percent <= 50:
        img_quality = 50
        max_dimension = 1200
    else:  # 75%
        img_quality = 70
        max_dimension = 1600
    
    for page_num in range(len(doc)):
        page = doc[page_num]
        new_page = new_doc.new_page(width=page.rect.width, height=page.rect.height)
        
        # Get images on this page
        image_list = page.get_images()
        
        if image_list:
            # Process each image
            for img_index, img in enumerate(image_list):
                try:
                    xref = img[0]
                    base_image = doc.extract_image(xref)
                    image_bytes = base_image["image"]
                    
                    # Convert to PIL Image
                    pil_image = Image.open(io.BytesIO(image_bytes))
                    
                    # Resize if too large
                    if max(pil_image.size) > max_dimension:
                        ratio = max_dimension / max(pil_image.size)
                        new_size = tuple(int(dim * ratio) for dim in pil_image.size)
                        pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
                    
                    # Compress image
                    img_buffer = io.BytesIO()
                    if pil_image.mode in ['RGBA', 'LA']:
                        pil_image = pil_image.convert('RGB')
                    
                    pil_image.save(img_buffer, format='JPEG', quality=img_quality, optimize=True)
                    img_buffer.seek(0)
                    
                    # Replace the image in place (Page.replace_image requires PyMuPDF >= 1.21)
                    page.replace_image(xref, stream=img_buffer.getvalue())
                    
                except Exception as e:
                    logger.warning(f"   ⚠️  Could not process image {img_index}: {e}")
                    continue
        
        # Copy page content
        new_page.show_pdf_page(page.rect, doc, page_num)
    
    # Save with compression
    new_doc.save(output_path,
                 garbage=4,
                 clean=True,
                 deflate=True,
                 deflate_images=True,
                 deflate_fonts=True)
    
    new_doc.close()
    return output_path

def aggressive_pdf_compression(doc, output_path, target_percent):
    """Aggressive compression that may reduce quality significantly."""
    new_doc = fitz.open()
    
    # Very aggressive settings based on target
    if target_percent <= 25:
        dpi = 100
        img_quality = 20
    elif target_percent <= 50:
        dpi = 150
        img_quality = 35
    else:  # 75%
        dpi = 200
        img_quality = 55
    
    for page_num in range(len(doc)):
        page = doc[page_num]
        
        # Render page as image and re-insert (lossy but effective)
        mat = fitz.Matrix(dpi/72, dpi/72)  # Create scaling matrix
        pix = page.get_pixmap(matrix=mat)
        
        # Convert pixmap to PIL Image
        img_data = pix.tobytes("png")
        pil_image = Image.open(io.BytesIO(img_data))
        
        # Compress as JPEG
        img_buffer = io.BytesIO()
        if pil_image.mode in ['RGBA', 'LA']:
            pil_image = pil_image.convert('RGB')
        
        pil_image.save(img_buffer, format='JPEG', quality=img_quality, optimize=True)
        img_buffer.seek(0)
        
        # Create new page and insert compressed image
        new_page = new_doc.new_page(width=page.rect.width, height=page.rect.height)
        new_page.insert_image(page.rect, stream=img_buffer.getvalue())
    
    new_doc.save(output_path, garbage=4, clean=True, deflate=True)
    new_doc.close()
    return output_path

def hybrid_compression_approach(doc, output_path, target_percent):
    """Hybrid approach combining multiple compression techniques."""
    # First try basic compression
    temp_file = output_path.replace('.pdf', '_temp.pdf')
    basic_pdf_compression(doc, temp_file)
    
    # Check size after basic compression
    basic_size = os.path.getsize(temp_file)
    target_size = os.path.getsize(doc.name) * (target_percent / 100)
    
    if basic_size <= target_size:
        shutil.move(temp_file, output_path)
        return output_path
    
    # If not enough, apply image optimization
    doc_temp = fitz.open(temp_file)
    compress_with_image_optimization(doc_temp, output_path, target_percent)
    doc_temp.close()
    
    # Clean up
    if os.path.exists(temp_file):
        os.remove(temp_file)
    
    return output_path

def format_file_size(size_bytes):
    """Convert bytes to human readable format."""
    if size_bytes < 1024:
        return f"{size_bytes} B"
    elif size_bytes < 1024**2:
        return f"{size_bytes/1024:.1f} KB"
    elif size_bytes < 1024**3:
        return f"{size_bytes/(1024**2):.1f} MB"
    else:
        return f"{size_bytes/(1024**3):.1f} GB"

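# Quick illustration of format_file_size on sample values:
#   format_file_size(512)           -> "512 B"
#   format_file_size(2048)          -> "2.0 KB"
#   format_file_size(5 * 1024**2)   -> "5.0 MB"
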
def batch_compress_pdfs(directory_path, target_size_percent=50, output_suffix="_compressed"):
    """Compress all PDF files in a directory."""
    pdf_files = [f for f in sorted(os.listdir(directory_path)) if f.lower().endswith('.pdf')]
    
    if not pdf_files:
        logger.warning("No PDF files found in directory.")
        return
    
    logger.info(f"🗂️  Found {len(pdf_files)} PDF files to compress")
    
    results = []
    for i, pdf_file in enumerate(pdf_files, 1):
        input_path = os.path.join(directory_path, pdf_file)
        base_name = os.path.splitext(pdf_file)[0]
        output_path = os.path.join(directory_path, f"{base_name}{output_suffix}_{target_size_percent}percent.pdf")
        
        logger.info(f"\n📄 Processing {i}/{len(pdf_files)}: {pdf_file}")
        
        success, final_path, stats = compress_pdf(input_path, output_path, target_size_percent)
        results.append({
            'file': pdf_file,
            'success': success,
            'stats': stats
        })
    
    # Print summary
    logger.info(f"\n{'='*50}")
    logger.info(f"🎯 Batch Compression Summary")
    logger.info(f"{'='*50}")
    
    successful = sum(1 for r in results if r['success'])
    total_original_size = sum(r['stats'].get('original_size', 0) for r in results if r['success'])
    total_final_size = sum(r['stats'].get('final_size', 0) for r in results if r['success'])
    
    logger.info(f"✅ Successful: {successful}/{len(results)}")
    if total_original_size > 0:
        logger.info(f"💾 Total size reduction: {format_file_size(total_original_size - total_final_size)}")
        logger.info(f"📊 Overall compression: {((total_original_size - total_final_size) / total_original_size * 100):.1f}%")
    else:
        logger.info("💾 No files were successfully compressed, so there is no size reduction to report.")
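

if __name__ == "__main__":
    # Minimal command-line sketch (an assumption about how this module might be
    # run, not part of the original workflow). The directory below is a
    # placeholder and should be replaced with a real path.
    sample_directory = "/path/to/pdf_directory"
    if os.path.isdir(sample_directory):
        merge_pdfs_with_index(sample_directory, output_filename="merged_document.pdf")
        batch_compress_pdfs(sample_directory, target_size_percent=50)
    else:
        logger.warning(f"Sample directory not found: {sample_directory}")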