"""
Spatial Processor - 3D positioning logic

Handles the spatial audio processing that positions sources in 3D space
and converts them to the ambisonic format needed for SHAC encoding.
"""

import numpy as np
import logging
from typing import Tuple, List, Dict, Optional
import math

# Module-level logger, named after this module so its verbosity can be
# configured independently of other components.
logger = logging.getLogger(name=__name__)


class SpatialProcessor:
    """Handles 3D spatial audio processing.

    Mono sources are encoded into ambisonic signals using real spherical
    harmonics in ACN channel order (channel index ``l*l + l + m``) with SN3D
    normalization. This is the convention the first-order layout
    (W=1, Y, Z, X) and the zonal second-order channel already followed.
    """

    # Angular resolution, in degrees, of the precomputed direction table.
    # Must divide 90 so the table contains both poles.
    _TABLE_STEP_DEG = 10

    def __init__(self, order: int = 3, sample_rate: int = 48000):
        """Initialize spatial processor.

        Args:
            order: Ambisonic order (1-7)
            sample_rate: Audio sample rate
        """
        self.order = order
        self.sample_rate = sample_rate
        self.num_channels = (order + 1) ** 2

        # Pre-computed spherical harmonic coefficients
        self._precompute_sh_coefficients()

    def _precompute_sh_coefficients(self):
        """Fill ``self.sh_coefficients`` with SH gains on a regular direction grid.

        Keys are ``(azimuth_deg, elevation_deg)`` integer tuples, with azimuth in
        [0, 360) and elevation in [-90, 90], both in steps of ``_TABLE_STEP_DEG``.
        """
        step = self._TABLE_STEP_DEG
        self.sh_coefficients = {}

        for elevation in range(-90, 91, step):
            for azimuth in range(0, 360, step):
                self.sh_coefficients[(azimuth, elevation)] = self._calculate_sh_coefficients(
                    math.radians(azimuth),
                    math.radians(elevation),
                )

    @staticmethod
    def _associated_legendre(max_degree: int, sin_el: float, cos_el: float) -> List[List[float]]:
        """Return ``table[l][m] = P_l^m(sin_el)`` for ``0 <= m <= l <= max_degree``.

        The Legendre functions omit the Condon-Shortley phase, as is usual in
        ambisonics. The ``(1 - x**2) ** (m / 2)`` factor is supplied as
        ``cos_el ** m``, which keeps the sign correct even for elevations
        outside [-pi/2, pi/2].
        """
        table = [[0.0] * (max_degree + 1) for _ in range(max_degree + 1)]
        sectoral = 1.0  # P_m^m = (2m-1)!! * cos_el**m, built up incrementally
        for m in range(max_degree + 1):
            if m > 0:
                sectoral *= (2 * m - 1) * cos_el
            table[m][m] = sectoral
            if m + 1 <= max_degree:
                table[m + 1][m] = (2 * m + 1) * sin_el * sectoral
            # Standard three-term recurrence in the degree for fixed m.
            for degree in range(m + 2, max_degree + 1):
                table[degree][m] = (
                    (2 * degree - 1) * sin_el * table[degree - 1][m]
                    - (degree + m - 1) * table[degree - 2][m]
                ) / (degree - m)
        return table

    def _calculate_sh_coefficients(self, azimuth: float, elevation: float) -> np.ndarray:
        """Calculate spherical harmonic coefficients for a position.

        Uses real ACN/SN3D spherical harmonics for every degree up to
        ``self.order``. In each degree the squared gains sum to 1.

        Args:
            azimuth: Azimuth angle in radians
            elevation: Elevation angle in radians

        Returns:
            Array of ``self.num_channels`` SH coefficients in ACN order
        """
        coeffs = np.zeros(self.num_channels)
        legendre = self._associated_legendre(
            self.order, math.sin(elevation), math.cos(elevation)
        )

        for degree in range(self.order + 1):
            for m in range(-degree, degree + 1):
                abs_m = abs(m)
                # SN3D normalization: sqrt((2 - delta_m0) * (l-|m|)! / (l+|m|)!)
                norm = math.sqrt(
                    (1.0 if m == 0 else 2.0)
                    * math.factorial(degree - abs_m)
                    / math.factorial(degree + abs_m)
                )
                # Positive m uses cosine terms, negative m uses sine terms.
                azimuth_term = math.cos(abs_m * azimuth) if m >= 0 else math.sin(abs_m * azimuth)
                coeffs[degree * degree + degree + m] = norm * legendre[degree][abs_m] * azimuth_term

        return coeffs

    def cartesian_to_spherical(self, x: float, y: float, z: float) -> Tuple[float, float, float]:
        """Convert Cartesian coordinates to spherical.

        Args:
            x, y, z: Cartesian coordinates

        Returns:
            (distance, azimuth, elevation) in (meters, radians, radians).
            Azimuth is measured from +x towards +y and lies in (-pi, pi].
            Elevation lies in [-pi/2, pi/2]. The origin maps to (0, 0, 0).
        """
        distance = np.sqrt(x * x + y * y + z * z)

        if distance == 0:
            return 0, 0, 0

        azimuth = np.arctan2(y, x)
        # atan2 against the horizontal radius is well-conditioned near the
        # poles and cannot leave the arcsin domain the way z / distance can.
        elevation = np.arctan2(z, np.hypot(x, y))

        return distance, azimuth, elevation

    @staticmethod
    def _to_mono(audio: np.ndarray) -> np.ndarray:
        """Reduce ``audio`` to a 1D signal.

        1D input is returned unchanged. For 2D ``(samples, channels)`` input,
        a single column is flattened; otherwise the first (left) channel is
        kept rather than downmixing.
        """
        audio = np.asarray(audio)
        if audio.ndim <= 1:
            return audio

        # Read the channel count BEFORE slicing: afterwards the array is 1D.
        num_input_channels = audio.shape[1]
        if num_input_channels == 1:
            # Already mono in 2D format
            return audio.flatten()
        if num_input_channels == 2:
            logger.debug("Spatial processing: Using left channel from stereo")
        else:
            logger.debug(
                "Spatial processing: Using first channel from %d-channel audio",
                num_input_channels,
            )
        return audio[:, 0]

    def position_source(self, audio: np.ndarray, position: Tuple[float, float, float]) -> np.ndarray:
        """Position mono audio source in 3D space.

        Args:
            audio: Mono audio signal (1D array). For 2D (samples x channels)
                input only the first channel is used.
            position: (x, y, z) position in meters

        Returns:
            Ambisonic audio (2D array: samples x channels)
        """
        x, y, z = position

        distance, azimuth, elevation = self.cartesian_to_spherical(x, y, z)
        coeffs = self._get_sh_coefficients(azimuth, elevation)
        attenuation = self._calculate_distance_attenuation(distance)

        mono = self._to_mono(audio)

        # Each ambisonic channel is the mono signal scaled by its directional
        # gain and the distance attenuation: an outer product.
        return np.outer(mono, coeffs * attenuation)

    def _get_sh_coefficients(self, azimuth: float, elevation: float) -> np.ndarray:
        """Get spherical harmonic coefficients for a direction.

        Looks up the nearest entry in the precomputed table. There is no
        interpolation, so the direction is quantized to ``_TABLE_STEP_DEG``
        degrees. Falls back to an exact calculation if the key is missing.

        Raises:
            ValueError: if either angle is NaN or infinite.
        """
        if not (math.isfinite(azimuth) and math.isfinite(elevation)):
            raise ValueError(
                f"Direction must be finite, got azimuth={azimuth!r}, elevation={elevation!r}"
            )

        step = self._TABLE_STEP_DEG
        az_deg = math.degrees(azimuth) % 360.0
        el_deg = min(90.0, max(-90.0, math.degrees(elevation)))

        # Round half-up to the nearest grid point. Azimuth wraps, so e.g.
        # 356 degrees maps to 0 rather than to a nonexistent 360 key.
        az_key = (math.floor(az_deg / step + 0.5) * step) % 360
        el_key = math.floor(el_deg / step + 0.5) * step

        coeffs = self.sh_coefficients.get((az_key, el_key))
        if coeffs is None:
            # Calculate on demand if not in table
            coeffs = self._calculate_sh_coefficients(azimuth, elevation)
        return coeffs

    def _calculate_distance_attenuation(self, distance: float) -> float:
        """Calculate distance attenuation factor.

        Applies an inverse-distance (1/r) law clamped to [0.01, 1.0]. Because
        of the upper clamp, sources closer than 1 m get unit gain; the 0.1 m
        floor only guards the division.

        Args:
            distance: Distance in meters

        Returns:
            Attenuation factor (0.01-1)
        """
        if distance <= 0:
            return 1.0

        min_distance = 0.1  # 10cm minimum
        effective_distance = max(distance, min_distance)

        attenuation = 1.0 / effective_distance

        return max(0.01, min(1.0, attenuation))

    def mix_sources(self, source_audio_list: List[Tuple[np.ndarray, Tuple[float, float, float]]]) -> np.ndarray:
        """Mix multiple positioned sources into ambisonic output.

        Sources shorter than the longest one are treated as zero-padded at
        the end.

        Args:
            source_audio_list: List of (audio, position) tuples

        Returns:
            Mixed ambisonic audio (samples x channels). An empty list yields
            a silent 1024-sample buffer.
        """
        if not source_audio_list:
            return np.zeros((1024, self.num_channels))  # Empty buffer

        max_length = max(len(audio) for audio, _ in source_audio_list)
        mixed_ambisonic = np.zeros((max_length, self.num_channels))

        for audio, position in source_audio_list:
            positioned = self.position_source(audio, position)
            # Accumulate into the leading rows; no padded copy needed.
            mixed_ambisonic[:len(positioned)] += positioned

        return mixed_ambisonic

    def set_order(self, order: int):
        """Change ambisonic order (1-7) and rebuild the coefficient table.

        An out-of-range order is ignored, with a warning, and the current
        order is kept.
        """
        if not 1 <= order <= 7:
            logger.warning("Ignoring invalid ambisonic order %r (expected 1-7)", order)
            return
        if order != self.order:
            self.order = order
            self.num_channels = (order + 1) ** 2
            self._precompute_sh_coefficients()

    def get_channel_count(self) -> int:
        """Get number of ambisonic channels."""
        return self.num_channels

    def validate_position(self, position: Tuple[float, float, float]) -> Tuple[float, float, float]:
        """Validate and clamp position to reasonable range.

        Each axis is clamped independently to [-100, 100] meters.

        Args:
            position: (x, y, z) position

        Returns:
            Validated position
        """
        x, y, z = position

        max_distance = 100.0
        x = max(-max_distance, min(max_distance, x))
        y = max(-max_distance, min(max_distance, y))
        z = max(-max_distance, min(max_distance, z))

        return (x, y, z)