import os
import json
import datetime
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import networkx as nx
from matplotlib.colors import LinearSegmentedColormap
from typing import Dict, List, Any

class ExecutionVisualizer:
    """Utility for visualizing CrewAI-Playwright execution flows"""

    # Palette for agent-graph nodes; cycled so any number of nodes can be drawn.
    NODE_COLORS = ["#4CAF50", "#2196F3", "#FFC107", "#E91E63", "#9C27B0"]

    # Marker colour per timeline event status (also used to build the legend).
    TIMELINE_STATUS_COLORS = {
        "in_progress": "#FFC107",  # Yellow
        "success": "#4CAF50",      # Green
        "failed": "#F44336",       # Red
        "retry": "#2196F3",        # Blue
        "pending": "#9E9E9E"       # Gray
    }
    # Marker colour for timeline statuses not listed above (purple).
    TIMELINE_DEFAULT_COLOR = "#9C27B0"

    # Bar colours for the summary status chart, in the same order as the keys
    # produced by _collect_statistics (completed, failed, in_progress, pending, retry).
    SUMMARY_STATUS_COLORS = ['#4CAF50', '#F44336', '#FFC107', '#9E9E9E', '#2196F3']

    def __init__(self, execution_data: Dict[str, Any], output_dir: str = "visualizations"):
        """
        Initialize the visualizer

        Args:
            execution_data: The execution data from TaskExecutionManager
            output_dir: Directory the generated images are written to; created
                if it does not exist. Defaults to "visualizations".
        """
        self.execution_data = execution_data
        self.output_dir = output_dir

        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(self.output_dir, exist_ok=True)

    def generate_visualizations(self) -> Dict[str, str]:
        """
        Generate all visualizations for the execution data

        The agent graph and timeline are built from
        ``execution_data["flow_visualization"]``; the summary is built from
        ``execution_data["execution_tracking"]`` and is therefore attempted
        regardless of whether flow visualization data is present.

        Returns:
            Dictionary with paths to generated visualizations, keyed by
            "agent_graph", "timeline" and "summary". A key is omitted when the
            corresponding image could not be produced.
        """
        output_files: Dict[str, str] = {}

        if "flow_visualization" in self.execution_data:
            vis_data = self.execution_data["flow_visualization"]

            graph_path = self._generate_agent_graph(vis_data)
            if graph_path:
                output_files["agent_graph"] = graph_path

            timeline_path = self._generate_timeline(vis_data)
            if timeline_path:
                output_files["timeline"] = timeline_path

        # The summary depends only on execution_tracking, so it is not gated on
        # flow_visualization.
        summary_path = self._generate_execution_summary()
        if summary_path:
            output_files["summary"] = summary_path

        return output_files

    @staticmethod
    def _cycle_colors(palette: List[str], count: int) -> List[str]:
        """Return ``count`` colours taken from ``palette`` in order, wrapping around."""
        return [palette[i % len(palette)] for i in range(count)]

    def _generate_agent_graph(self, visualization_data: Dict[str, Any]) -> str:
        """
        Generate a graph visualization of agent interactions

        Each entry of ``visualization_data["agents"]`` becomes a node and each
        entry of ``visualization_data["connections"]`` (a dict with "from",
        "to" and an optional "timestamp") becomes a directed edge.

        Args:
            visualization_data: The flow visualization data

        Returns:
            Path to the generated image file, or "" when there are no agents
            or no connections, or when rendering fails.
        """
        fig = None
        try:
            agents = visualization_data.get("agents", [])
            connections = visualization_data.get("connections", [])

            if not agents or not connections:
                return ""

            G = nx.DiGraph()
            for agent in agents:
                G.add_node(agent)
            for conn in connections:
                # Endpoints that are not in `agents` are added as nodes by add_edge.
                G.add_edge(conn["from"], conn["to"], timestamp=conn.get("timestamp"))

            fig = plt.figure(figsize=(12, 8))
            pos = nx.spring_layout(G, seed=42)  # Fixed seed for a reproducible layout

            # node_color must supply one colour per node in the graph, which can
            # differ from len(agents) (duplicates, extra endpoints, >5 agents).
            nx.draw_networkx_nodes(G, pos,
                                   node_size=2000,
                                   alpha=0.8,
                                   node_color=self._cycle_colors(self.NODE_COLORS,
                                                                 G.number_of_nodes()))

            nx.draw_networkx_edges(G, pos, width=2, alpha=0.7, edge_color="#555555",
                                   arrowsize=20, arrowstyle='->')

            nx.draw_networkx_labels(G, pos, font_size=10, font_family='sans-serif')

            output_path = os.path.join(self.output_dir, "agent_interaction_graph.png")
            plt.title("Agent Interaction Graph", fontsize=16)
            plt.axis('off')
            plt.tight_layout()
            plt.savefig(output_path, dpi=300, bbox_inches='tight')

            return output_path

        except Exception as e:
            print(f"Error generating agent graph: {str(e)}")
            return ""
        finally:
            # Always release the figure, including on the error path.
            if fig is not None:
                plt.close(fig)

    @staticmethod
    def _prepare_timeline(timeline: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Return the timeline events as chronologically sorted shallow copies.

        Each copy gains a "datetime" key parsed from its "timestamp"
        ("%Y-%m-%d %H:%M:%S"). The input list and its dicts are not modified.
        """
        events = [
            {**item,
             "datetime": datetime.datetime.strptime(item["timestamp"], "%Y-%m-%d %H:%M:%S")}
            for item in timeline
        ]
        events.sort(key=lambda event: event["datetime"])
        return events

    def _generate_timeline(self, visualization_data: Dict[str, Any]) -> str:
        """
        Generate a timeline visualization of agent activities

        Each agent gets one row; every event is a dot coloured by its
        "status". Events whose "action" is "complete_task" or "handle_failure"
        are annotated with their subtask id.

        Args:
            visualization_data: The flow visualization data

        Returns:
            Path to the generated image file, or "" when the timeline is empty
            or rendering fails.
        """
        fig = None
        try:
            timeline = visualization_data.get("timeline", [])

            if not timeline:
                return ""

            events = self._prepare_timeline(timeline)

            # Group by agent; rows follow each agent's first appearance in time.
            agents: Dict[Any, List[Dict[str, Any]]] = {}
            for event in events:
                agents.setdefault(event["agent"], []).append(event)

            fig, ax = plt.subplots(figsize=(15, 8))
            status_colors = self.TIMELINE_STATUS_COLORS

            y_labels = []
            for y_pos, (agent_name, agent_events) in enumerate(agents.items()):
                y_labels.append(agent_name)

                for event in agent_events:
                    color = status_colors.get(event.get("status"), self.TIMELINE_DEFAULT_COLOR)
                    ax.scatter(event["datetime"], y_pos, color=color, s=100, alpha=0.7)

                    # Label only the terminal points of a subtask to limit clutter.
                    if event.get("action") in ("complete_task", "handle_failure"):
                        ax.annotate(
                            f"Subtask: {event.get('subtask_id', 'unknown')}",
                            (event["datetime"], y_pos),
                            xytext=(10, 0),
                            textcoords="offset points",
                            fontsize=8,
                            alpha=0.7
                        )

            ax.set_yticks(range(len(y_labels)))
            ax.set_yticklabels(y_labels)
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
            plt.xticks(rotation=45)

            legend_elements = [plt.Line2D([0], [0], marker='o', color='w',
                                          markerfacecolor=status_color, markersize=10,
                                          label=status)
                               for status, status_color in status_colors.items()]
            ax.legend(handles=legend_elements, loc='upper right')

            ax.grid(True, alpha=0.3)
            ax.set_title("Agent Activity Timeline", fontsize=16)
            plt.tight_layout()

            output_path = os.path.join(self.output_dir, "execution_timeline.png")
            plt.savefig(output_path, dpi=300, bbox_inches='tight')

            return output_path

        except Exception as e:
            print(f"Error generating timeline: {str(e)}")
            return ""
        finally:
            if fig is not None:
                plt.close(fig)

    @staticmethod
    def _collect_statistics(execution_tracking: Dict[Any, Dict[str, Any]]) -> tuple:
        """
        Count subtasks per status and per number of attempts.

        Statuses outside the known set (including a missing status) are
        counted as "pending"; a missing "attempts" value counts as 0.

        Returns:
            (status_counts, attempt_counts) where status_counts maps each of
            completed/failed/in_progress/pending/retry to a count and
            attempt_counts maps an attempts value to the number of subtasks.
        """
        status_counts = {"completed": 0, "failed": 0, "in_progress": 0, "pending": 0, "retry": 0}
        attempt_counts: Dict[Any, int] = {}

        for data in execution_tracking.values():
            status = data.get("status", "unknown")
            status_counts[status if status in status_counts else "pending"] += 1

            attempts = data.get("attempts", 0)
            attempt_counts[attempts] = attempt_counts.get(attempts, 0) + 1

        return status_counts, attempt_counts

    def _generate_execution_summary(self) -> str:
        """
        Generate a summary visualization of execution statistics

        Draws two bar charts from ``execution_data["execution_tracking"]``:
        subtasks per status, and subtasks per number of attempts (coloured
        green to red as attempts grow).

        Returns:
            Path to the generated image file, or "" when there is no tracking
            data or rendering fails.
        """
        fig = None
        try:
            execution_tracking = self.execution_data.get("execution_tracking", {})

            if not execution_tracking:
                return ""

            status_counts, attempt_counts = self._collect_statistics(execution_tracking)

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))

            ax1.bar(list(status_counts.keys()), list(status_counts.values()),
                    color=self.SUMMARY_STATUS_COLORS)
            ax1.set_title('Task Status Distribution', fontsize=14)
            ax1.set_ylabel('Number of Tasks')
            ax1.grid(axis='y', alpha=0.3)

            attempt_labels = sorted(attempt_counts.keys())
            attempt_values = [attempt_counts[a] for a in attempt_labels]

            # Scale attempts into [0, 1] for the green->red colormap; max(1, ...)
            # avoids dividing by zero when every task has 0 attempts.
            cmap = LinearSegmentedColormap.from_list('GreenToRed', ['green', 'yellow', 'red'])
            scale = max(1, max(attempt_labels))
            colors = cmap([a / scale for a in attempt_labels])

            ax2.bar(attempt_labels, attempt_values, color=colors)
            ax2.set_title('Retry Attempts Distribution', fontsize=14)
            ax2.set_xlabel('Number of Attempts')
            ax2.set_ylabel('Number of Tasks')
            ax2.set_xticks(attempt_labels)
            ax2.grid(axis='y', alpha=0.3)

            plt.suptitle('Execution Summary', fontsize=16)
            plt.tight_layout(rect=[0, 0, 1, 0.95])

            output_path = os.path.join(self.output_dir, "execution_summary.png")
            plt.savefig(output_path, dpi=300, bbox_inches='tight')

            return output_path

        except Exception as e:
            print(f"Error generating execution summary: {str(e)}")
            return ""
        finally:
            if fig is not None:
                plt.close(fig)

def visualize_execution(execution_result: Dict[str, Any]) -> Dict[str, str]:
    """
    Render every available visualization for an execution result.

    Convenience wrapper that builds an ExecutionVisualizer around
    ``execution_result`` and runs it.

    Args:
        execution_result: The execution result from main()

    Returns:
        Mapping of visualization type to the path of its generated image
    """
    return ExecutionVisualizer(execution_result).generate_visualizations()

# Example usage
if __name__ == "__main__":
    # Load example data from a JSON file and render every visualization it supports.
    try:
        # JSON text is UTF-8 by specification; don't depend on the platform locale.
        with open("execution_result.json", "r", encoding="utf-8") as f:
            execution_data = json.load(f)

        visualization_paths = visualize_execution(execution_data)

        # Report where each generated image was written.
        for vis_type, path in visualization_paths.items():
            print(f"{vis_type}: {path}")

    except Exception as e:
        # Top-level boundary for the example script: report and exit quietly.
        print(f"Error visualizing execution data: {str(e)}")