"""
Native Python Spatial Audio Player
Based on the JavaScript spatial audio engine but implemented natively in Python
"""

import numpy as np
import threading
import time
import logging
from typing import Dict, List, Tuple, Optional
import math

# Use sounddevice instead of PyAudio (easier to install, better cross-platform support)
# The import is optional: AUDIO_AVAILABLE records whether it succeeded, and
# NativeSpatialAudioEngine copies that flag into self.audio_available to decide
# whether playback can be attempted at all.
try:
    import sounddevice as sd
    AUDIO_AVAILABLE = True
    logging.info(f"sounddevice v{sd.__version__} imported successfully for spatial playback")
except ImportError as e:
    # The module could not be imported: continue with audio disabled.
    AUDIO_AVAILABLE = False
    logging.warning(f"sounddevice not available - audio playback disabled: {e}")
except Exception as e:
    # Any other error raised inside the try body (the import itself or the
    # version log line) also disables audio; exc_info=True logs the traceback.
    AUDIO_AVAILABLE = False
    logging.error(f"Unexpected error importing sounddevice: {e}", exc_info=True)

class NativeSpatialAudioEngine:
    """Native Python implementation of spatial audio engine with HRTF and movement.

    Each loaded source ("layer") is a mono signal with a 3D position. During
    playback the sounddevice callback pulls the next chunk of every layer,
    scales it by per-ear gains derived from the source position relative to
    the listener, sums the layers into a stereo buffer and limits the result.
    """

    def __init__(self, sample_rate=48000, buffer_size=1024):
        """Create the engine and, when sounddevice is importable, probe the output device.

        Args:
            sample_rate: Requested output rate in Hz. ``setup_audio_output``
                replaces it with the device's default rate when they differ.
            buffer_size: Number of frames per audio callback block.
        """
        self.sample_rate = sample_rate
        self.buffer_size = buffer_size

        # Audio layers (sources), keyed by layer name
        self.layers = {}
        self.is_playing = False
        self.current_time = 0.0
        self.duration = 0.0

        # Listener position and orientation (rotation angles in degrees)
        self.listener_position = np.array([0.0, 0.0, 0.0])
        self.listener_rotation = {'azimuth': 0.0, 'elevation': 0.0, 'roll': 0.0}

        # Movement parameters
        self.move_speed = 0.17
        self.rotate_speed = 1.7

        # Audio output
        self.audio_available = AUDIO_AVAILABLE
        self.output_stream = None
        self.playback_thread = None
        self.stop_playback = False
        self.audio_buffer = []

        # Initialize audio output if available
        if not self.audio_available:
            logging.warning("Audio output not available - playback will be disabled")
        else:
            self.setup_audio_output()

    def setup_audio_output(self):
        """Set up sounddevice for real-time output.

        Queries the default output device, adopts its default sample rate when
        it differs from the requested one, and opens/closes a throwaway stream
        to verify that output works. On any failure ``audio_available`` is set
        to False and the engine continues without sound.
        """
        if not self.audio_available:
            return

        try:
            # Get default output device info
            device_info = sd.query_devices(kind='output')
            logging.info(f"Audio device found: {device_info['name']}")

            # Use device's native sample rate if our requested rate doesn't match
            device_sample_rate = int(device_info['default_samplerate'])
            if self.sample_rate != device_sample_rate:
                logging.info(f"Adjusting sample rate from {self.sample_rate}Hz to device native {device_sample_rate}Hz")
                self.sample_rate = device_sample_rate

            logging.info(f"Audio configured: {self.sample_rate}Hz, {self.buffer_size} frames buffer")

            # Test if we can actually open a stream (it is never started)
            test_stream = sd.OutputStream(
                samplerate=self.sample_rate,
                channels=2,
                blocksize=self.buffer_size,
                dtype='float32'
            )
            test_stream.close()
            logging.info("Audio output verified and ready")

        except Exception as e:
            logging.error(f"Failed to initialize audio output: {e}", exc_info=True)
            self.audio_available = False
            logging.warning("Audio playback will be disabled - continuing without sound")

    def load_sources(self, sources_data):
        """Load audio sources with spatial positions, replacing any existing layers.

        Args:
            sources_data: Iterable of mappings with the keys ``'name'``,
                ``'audio_data'`` (1-D samples, or 2-D frames x channels; a
                plain sequence is accepted as well as an ndarray) and
                ``'position'`` (x, y, z).

        Returns:
            The number of layers now loaded.

        Note:
            Durations are computed with the engine's current ``sample_rate``;
            the samples are not resampled and their dtype is left untouched.
        """
        self.layers.clear()
        max_duration = 0.0

        for source in sources_data:
            layer_name = source['name']
            # asarray accepts lists/tuples too and does not copy an existing ndarray
            audio_data = np.asarray(source['audio_data'])
            position = np.array(source['position'])

            # Reduce multi-channel audio to mono without mixing channels
            if len(audio_data.shape) > 1:
                if audio_data.shape[1] == 1:
                    audio_data = audio_data.flatten()
                else:
                    # Stereo or multi-channel: keep only the first (left) channel
                    audio_data = audio_data[:, 0]

            # Calculate duration
            duration = len(audio_data) / self.sample_rate
            max_duration = max(max_duration, duration)

            # Preserve original format - no forced float32 conversion
            self.layers[layer_name] = {
                'audio_data': audio_data,
                'position': position,
                'duration': duration,
                'current_sample': 0
            }

            logging.info(f"Loaded source '{layer_name}' at position {position}, duration {duration:.2f}s")

        self.duration = max_duration
        logging.info(f"Total composition duration: {self.duration:.2f}s")
        return len(self.layers)

    def _close_stream(self):
        """Stop and close the current output stream, if there is one."""
        stream, self.output_stream = self.output_stream, None
        if stream is None:
            return
        try:
            stream.stop()
        finally:
            # Always release the device, even if stopping failed
            stream.close()

    def play(self):
        """Start (or resume after ``pause``) spatial audio playback.

        Does nothing when already playing, when no layers are loaded, or when
        audio output is unavailable. If the stream cannot be opened or started
        the error is logged, the stream is released, ``audio_available`` is
        set to False and ``is_playing`` stays False.
        """
        if not self.audio_available:
            logging.warning("Cannot play: audio output not available")
            return
        if self.is_playing or not self.layers:
            return

        stream = None
        try:
            # A paused stream is stopped but still open; release it before
            # opening a new one so the old stream is not leaked.
            self._close_stream()

            # Open audio stream with sounddevice
            stream = sd.OutputStream(
                samplerate=self.sample_rate,
                channels=2,  # Stereo output
                blocksize=self.buffer_size,
                dtype='float32',
                callback=self._audio_callback
            )

            # Must be cleared before start(): the callback outputs silence while it is set
            self.stop_playback = False
            stream.start()

        except Exception as e:
            logging.error(f"Failed to start playback: {e}")
            self.audio_available = False
            if stream is not None:
                stream.close()
            return

        # Only report "playing" once the stream has really started
        self.output_stream = stream
        self.is_playing = True
        logging.info("Started spatial audio playback")

    def stop(self):
        """Stop playback, release the output stream and rewind all layers.

        Also works on a paused engine (stream open but not playing), so a
        paused stream is closed and the position reset. A no-op when the
        engine is neither playing nor holding a stream.
        """
        if not self.is_playing and self.output_stream is None:
            return

        self.stop_playback = True
        self.is_playing = False

        self._close_stream()

        # Reset playback position
        self.current_time = 0.0
        for layer in self.layers.values():
            layer['current_sample'] = 0

        logging.info("Stopped spatial audio playback")

    def pause(self):
        """Pause spatial audio playback, keeping the current playback position."""
        if self.is_playing and self.output_stream:
            self.output_stream.stop()
            self.is_playing = False
            logging.info("Paused spatial audio playback")

    def _audio_callback(self, outdata, frames, time_info, status):
        """Real-time audio callback for spatial processing (sounddevice API).

        Fills ``outdata`` (frames x 2) in place with the limited mix of all
        layers and advances each layer's ``current_sample`` and the engine's
        ``current_time``. Any exception is logged and replaced by silence so
        it never propagates into the audio backend.
        """
        try:
            if status:
                logging.warning(f"Audio callback status: {status}")

            # Create output buffer (stereo)
            output = np.zeros((frames, 2), dtype=np.float32)

            if self.stop_playback:
                outdata[:] = output
                return

            # Iterate over a snapshot so a concurrent load_sources() on another
            # thread cannot change the dict while it is being iterated.
            for layer in list(self.layers.values()):
                audio_data = layer['audio_data']
                start_sample = layer['current_sample']
                end_sample = min(start_sample + frames, len(audio_data))

                # Nothing left to play for this layer (or zero-length block)
                if end_sample <= start_sample:
                    continue

                # Apply spatial processing to this layer's next chunk
                spatial_chunk = self._apply_spatial_processing(
                    audio_data[start_sample:end_sample], layer['position']
                )

                # Mix into the head of the buffer; a short final chunk leaves
                # the remaining frames silent for this layer.
                output[:end_sample - start_sample] += spatial_chunk

                # Update sample position
                layer['current_sample'] = end_sample

            # Update current time
            self.current_time += frames / self.sample_rate

            # Apply limiter to prevent clipping
            output = self._apply_limiter(output)

            # Copy to output buffer (sounddevice requires in-place modification)
            outdata[:] = output

        except Exception as e:
            logging.error(f"Audio callback error: {e}", exc_info=True)
            outdata[:] = np.zeros((frames, 2), dtype=np.float32)

    def _apply_spatial_processing(self, mono_audio, source_position):
        """Turn a mono chunk into stereo using position-dependent per-ear gains.

        Args:
            mono_audio: 1-D array of samples.
            source_position: Source position (x, y, z) as an array.

        Returns:
            float32 array of shape (len(mono_audio), 2): left, right.
        """
        # Calculate relative position (source relative to listener)
        relative_pos = source_position - self.listener_position

        # Convert to spherical coordinates; the distance floor avoids division by zero
        distance = max(float(np.linalg.norm(relative_pos)), 0.001)

        azimuth = math.atan2(relative_pos[0], relative_pos[2])
        # Clamp so floating-point rounding can never push asin() out of its domain
        sin_elevation = max(-1.0, min(1.0, relative_pos[1] / distance))
        elevation = math.asin(sin_elevation)

        # Apply listener rotation (stored in degrees)
        azimuth -= math.radians(self.listener_rotation['azimuth'])
        elevation -= math.radians(self.listener_rotation['elevation'])

        # Simple HRTF implementation
        hrtf = self._get_hrtf(azimuth, elevation, distance)

        # Apply per-ear gains to create stereo output
        stereo_output = np.zeros((len(mono_audio), 2), dtype=np.float32)
        stereo_output[:, 0] = mono_audio * hrtf['left']    # Left channel
        stereo_output[:, 1] = mono_audio * hrtf['right']   # Right channel

        return stereo_output

    def _get_hrtf(self, azimuth, elevation, distance):
        """Calculate simplified head-related gains for a source direction.

        Args:
            azimuth: Horizontal angle in radians; positive values favour the
                right ear.
            elevation: Vertical angle in radians.
            distance: Source distance; attenuation is 1/distance beyond 1.0.

        Returns:
            Dict with ``'left'`` and ``'right'`` linear gains (each at least
            0.01) and ``'itd'``, the interaural time difference in seconds.
            The ITD is only reported; ``_apply_spatial_processing`` does not
            apply any delay.
        """
        # Distance attenuation (no boost closer than unit distance)
        distance_gain = 1.0 / max(distance, 1.0)

        # Interaural Time Difference (ITD)
        head_radius = 0.0875  # 8.75cm average head radius
        sound_speed = 343     # m/s
        itd = (head_radius / sound_speed) * (azimuth + math.sin(azimuth))

        # Head shadow effect: the ear facing away from the source is attenuated
        head_shadow_left = 1.0 - 0.3 * max(0, math.sin(azimuth))
        head_shadow_right = 1.0 - 0.3 * max(0, -math.sin(azimuth))

        # Elevation-dependent pinna gain
        pinna_gain = self._get_pinna_response(elevation)

        # Calculate final gains (panning term * pinna * head shadow * distance)
        left_gain = (0.5 + 0.5 * math.cos(azimuth + math.pi/2)) * \
                   pinna_gain * head_shadow_left * distance_gain
        right_gain = (0.5 + 0.5 * math.cos(azimuth - math.pi/2)) * \
                    pinna_gain * head_shadow_right * distance_gain

        return {
            'left': max(0.01, left_gain),
            'right': max(0.01, right_gain),
            'itd': itd
        }

    def _get_pinna_response(self, elevation):
        """Return a scalar gain approximating pinna (outer ear) effects.

        Args:
            elevation: Elevation in radians; it is normalised by pi/2.

        Returns:
            ``1 + 0.2 * e`` above the horizon and ``1 + 0.1 * |e|`` at or
            below it, where ``e`` is the normalised elevation.
        """
        elevation_norm = elevation / (math.pi / 2)  # Normalize to [-1, 1]

        if elevation > 0:
            # Above horizon: larger boost
            return 1.0 + 0.2 * elevation_norm
        else:
            # At or below horizon: smaller boost
            return 1.0 + 0.1 * abs(elevation_norm)

    def _apply_limiter(self, audio):
        """Apply soft limiting to prevent clipping.

        When the block's peak exceeds 0.9 the whole block is scaled so the
        peak lands at ``0.9 + (peak - 0.9) * 0.1``; the result is then
        hard-clipped to +/-0.9. The input array is not modified.

        Args:
            audio: Array of samples (any shape).

        Returns:
            The limited samples; an empty input is returned unchanged.
        """
        # Simple soft limiter
        threshold = 0.9
        ratio = 0.1

        # np.max() raises on empty input, e.g. a zero-frame callback block
        if audio.size == 0:
            return audio

        # Plain Python float keeps the array's dtype when multiplying below
        peak = float(np.max(np.abs(audio)))

        if peak > threshold:
            # Apply compression to a copy so the caller's buffer is left intact
            target_peak = threshold + (peak - threshold) * ratio
            audio = audio * (target_peak / peak)

        # Hard limit at the threshold (0.9 is roughly -0.9 dBFS)
        return np.clip(audio, -threshold, threshold)

    def update_listener_position(self, x, y, z):
        """Update listener position in 3D space."""
        # Force a float array: integer coordinates would otherwise create an
        # integer array that cannot hold fractional movement steps.
        self.listener_position = np.array([x, y, z], dtype=float)
        logging.debug(f"Listener position: {self.listener_position}")

    def update_listener_rotation(self, azimuth, elevation, roll=0):
        """Update listener rotation (angles in degrees)."""
        self.listener_rotation = {
            'azimuth': azimuth,
            'elevation': elevation,
            'roll': roll
        }
        logging.debug(f"Listener rotation: {self.listener_rotation}")

    def move_listener(self, dx, dy, dz):
        """Move listener relative to current position, scaled by ``move_speed``."""
        step = np.array([dx, dy, dz], dtype=float) * self.move_speed
        # Rebind instead of "+=": works even if the stored array has an integer
        # dtype, and the audio callback never observes a half-updated position.
        self.listener_position = self.listener_position + step

    def get_current_time(self):
        """Get current playback time in seconds."""
        return self.current_time

    def get_duration(self):
        """Get total duration (longest layer) in seconds."""
        return self.duration

    def get_listener_position(self):
        """Get current listener position as a list [x, y, z]."""
        return self.listener_position.tolist()

    def get_listener_rotation(self):
        """Get current listener rotation dict (azimuth, elevation, roll)."""
        return self.listener_rotation

    def get_layer_info(self):
        """Get name, position (as list) and duration for every loaded layer."""
        return [{
            'name': name,
            'position': layer['position'].tolist(),
            'duration': layer['duration']
        } for name, layer in self.layers.items()]

    def cleanup(self):
        """Clean up audio resources."""
        self.stop()
        # sounddevice doesn't need explicit cleanup like PyAudio
        logging.info("Audio resources cleaned up")