"""
SHAC Studio Transparent Audio Engine

Revolutionary principle: PRESERVE EVERYTHING.
SHAC is about WHERE sound comes from, not WHAT the sound is.

No resampling. No limiting. No channel conversion. No molding.
Import exactly as recorded. Position in 3D space. That's it.
"""

import numpy as np
import logging
import time
from typing import Dict, List, Tuple, Optional, Callable
from pathlib import Path

# Audio loading with fallbacks
try:
    import soundfile as sf
    SOUNDFILE_AVAILABLE = True
except ImportError:
    SOUNDFILE_AVAILABLE = False

try:
    import librosa
    LIBROSA_AVAILABLE = True
except ImportError:
    LIBROSA_AVAILABLE = False

# Audio output for preview (optional dependency).
# Messages go through this module's logger rather than the root-level
# ``logging.info``/``logging.warning``/``logging.error`` helpers: those call
# ``logging.basicConfig()`` when the root logger has no handlers, which would
# configure the host application's logging as a side effect of importing this
# module.
try:
    import sounddevice as sd
    SOUNDDEVICE_AVAILABLE = True
    logging.getLogger(__name__).info(
        f"sounddevice v{sd.__version__} imported successfully for preview playback"
    )
except ImportError as e:
    SOUNDDEVICE_AVAILABLE = False
    logging.getLogger(__name__).warning(
        f"sounddevice not available - preview playback disabled: {e}"
    )
except Exception as e:
    # e.g. the package is installed but its native audio backend fails to load.
    SOUNDDEVICE_AVAILABLE = False
    logging.getLogger(__name__).error(
        f"Unexpected error importing sounddevice: {e}", exc_info=True
    )

# State management
from .state_manager import StatefulComponent, StateEvent, state_manager

# Set up logger
logger = logging.getLogger(__name__)


class TransparentAudioEngine(StatefulComponent):
    """
    Transparent audio engine that preserves original audio quality completely.

    Core principle: Import exactly as recorded, position in 3D space, output exactly as intended.
    No molding, no processing, no "improvements" - just transparent spatial positioning.

    Sources are stored exactly as loaded. The only runtime math is a per-source
    gain (volume, distance rolloff, left/right balance) applied while rendering
    a stereo block; mono and multi-channel material is mapped to two output
    channels at render time only, never in the stored data.
    """

    def __init__(self):
        super().__init__("TransparentAudioEngine")

        # Audio sources - preserved in original format
        self.sources = {}  # source_id: AudioSource

        # Session info - adapts to content, doesn't force changes
        self.session_sample_rate = None  # Determined by first loaded file
        self.session_bit_depth = None    # Preserved from source
        self.session_channels = None     # Preserved from source

        # Spatial positioning only - no audio molding
        self.listener_position = np.array([0.0, 0.0, 0.0])
        self.listener_rotation = np.array([0.0, 0.0])  # azimuth, elevation

        # Audio callback for real-time playback
        self.position_callback = None

        # Source preview playback
        self.preview_source_id = None
        self.preview_position = 0.0  # seconds
        self.preview_playing = False
        self.preview_stream = None  # sounddevice output stream for preview
        self.preview_audio_available = SOUNDDEVICE_AVAILABLE

        # Quality preservation metrics
        self.import_stats = []

        logger.info("Transparent Audio Engine initialized - preserving original quality completely")
        if not self.preview_audio_available:
            logger.warning("Preview audio playback disabled - sounddevice not available")

    def add_source(self, source_id: str, audio_data: Dict, position: Tuple[float, float, float] = (0, 0, 0)) -> bool:
        """
        Add audio source preserving original quality completely.
        Only stores position info - no audio processing whatsoever.

        Args:
            source_id: Key the source is stored under (an existing source with
                the same id is replaced).
            audio_data: Loader result; must contain 'audio_data', 'sample_rate',
                'channels' and 'bit_depth'. The whole dict is kept as format_info.
            position: Initial (x, y, z) position.

        Returns:
            True on success, False if the source could not be created (the
            error is logged).
        """
        try:
            # Create source preserving everything
            source = AudioSource(
                source_id=source_id,
                audio_data=audio_data['audio_data'],  # Exactly as loaded
                sample_rate=audio_data['sample_rate'],  # Original rate
                channels=audio_data['channels'],  # Original channels
                bit_depth=audio_data['bit_depth'],  # Original bit depth
                position=np.array(position, dtype=np.float64),
                volume=1.0,  # Unity gain - no molding
                format_info=audio_data
            )

            self.sources[source_id] = source

            # The first loaded file defines the session format. Without this the
            # session rate stays None and get_spatial_audio_sample() can only
            # ever return silence.
            if self.session_sample_rate is None:
                self.session_sample_rate = source.sample_rate
                self.session_bit_depth = source.bit_depth
                self.session_channels = source.channels
            elif source.sample_rate != self.session_sample_rate:
                # Nothing is resampled: the spatial mix reads this source
                # frame-for-frame alongside the session-rate sources.
                logger.warning(
                    f"Source '{source_id}' is {source.sample_rate}Hz but the session is "
                    f"{self.session_sample_rate}Hz - it will not play at its native speed in the spatial mix"
                )

            # Emit state change for UI updates
            self.emit_state_change(StateEvent.SOURCE_ADDED, source_id, {
                'name': source_id,
                'duration': source.duration,
                'sample_rate': source.sample_rate,
                'channels': source.channels,
                'position': position
            })

            logger.info(f"Added source '{source_id}' preserving {source.sample_rate}Hz quality")
            return True

        except Exception as e:
            logger.error(f"Failed to add source {source_id}: {e}", exc_info=True)
            return False

    def set_source_position(self, source_id: str, position: Tuple[float, float, float]):
        """
        Set source position - pure spatial math, no audio processing.
        """
        if source_id in self.sources:
            self.sources[source_id].position = np.array(position, dtype=np.float64)

            # Emit state change
            self.emit_state_change(StateEvent.SOURCE_POSITION_CHANGED, source_id, {
                'position': position
            })

            logger.debug(f"Positioned '{source_id}' at {position}")
        else:
            logger.warning(f"Source '{source_id}' not found for positioning")

    def set_source_volume(self, source_id: str, volume: float):
        """
        Set source volume - simple gain, no limiting or processing.

        Negative values are clamped to 0.0; there is no upper bound.
        """
        if source_id in self.sources:
            # Simple volume - no processing
            self.sources[source_id].volume = max(0.0, volume)
            # Log the value actually stored (after clamping).
            logger.debug(f"Volume '{source_id}': {self.sources[source_id].volume:.2f}")
        else:
            logger.warning(f"Source '{source_id}' not found for volume")

    def mute_source(self, source_id: str):
        """Mute a specific source."""
        if source_id in self.sources:
            self.sources[source_id].is_muted = True
            logger.debug(f"Muted '{source_id}'")

            # Emit state change
            self.emit_state_change(StateEvent.SOURCE_MUTED, source_id, {
                'muted': True
            })
        else:
            logger.warning(f"Source '{source_id}' not found for muting")

    def unmute_source(self, source_id: str):
        """Unmute a specific source."""
        if source_id in self.sources:
            self.sources[source_id].is_muted = False
            logger.debug(f"Unmuted '{source_id}'")

            # Emit state change
            self.emit_state_change(StateEvent.SOURCE_UNMUTED, source_id, {
                'muted': False
            })
        else:
            logger.warning(f"Source '{source_id}' not found for unmuting")

    def set_source_mute(self, source_id: str, muted: bool):
        """Set mute state for a source (unified interface)."""
        if muted:
            self.mute_source(source_id)
        else:
            self.unmute_source(source_id)

    def is_source_muted(self, source_id: str) -> bool:
        """Check if a source is muted (False for unknown sources)."""
        if source_id in self.sources:
            return self.sources[source_id].is_muted
        return False

    def solo_source(self, source_id: str):
        """Solo a specific source (mute all others)."""
        if source_id in self.sources:
            # Mute all sources except the selected one
            for sid, source in self.sources.items():
                source.is_muted = (sid != source_id)

            logger.debug(f"Soloed '{source_id}' (muted {len(self.sources)-1} others)")

            # Emit state change for all sources
            for sid in self.sources.keys():
                if sid == source_id:
                    self.emit_state_change(StateEvent.SOURCE_UNMUTED, sid, {'muted': False})
                else:
                    self.emit_state_change(StateEvent.SOURCE_MUTED, sid, {'muted': True})
        else:
            logger.warning(f"Source '{source_id}' not found for soloing")

    def unsolo_all_sources(self):
        """Unmute all sources (clear solo)."""
        for sid, source in self.sources.items():
            source.is_muted = False
            self.emit_state_change(StateEvent.SOURCE_UNMUTED, sid, {'muted': False})

        logger.debug("Unsoloed all sources")

    def remove_source(self, source_id: str) -> bool:
        """Remove a source from the audio engine.

        Returns:
            True if the source existed and was removed, False otherwise.
        """
        if source_id not in self.sources:
            logger.warning(f"Source '{source_id}' not found for removal")
            return False

        # Stop a preview of this source first so the preview callback no
        # longer reads it while it is being removed.
        if source_id == self.preview_source_id:
            self.stop_source_preview()

        del self.sources[source_id]
        logger.info(f"Removed source '{source_id}'")

        # Emit state change
        self.emit_state_change(StateEvent.SOURCE_REMOVED, source_id, {})
        return True

    def clear_sources(self):
        """Remove all sources from the audio engine."""
        source_ids = list(self.sources.keys())
        for source_id in source_ids:
            self.remove_source(source_id)
        logger.info("Cleared all sources")

    # Source preview methods

    def start_source_preview(self, source_id: str, start_position: float = 0.0):
        """Start previewing a single source with actual audio output.

        Args:
            source_id: Source to preview.
            start_position: Start time in seconds (clamped to the source length).

        Returns:
            True if the preview started; False if the source is unknown,
            sounddevice is unavailable, or the output stream could not start.
        """
        if source_id not in self.sources:
            return False

        if not self.preview_audio_available:
            logger.warning("Cannot start preview - sounddevice not available")
            return False

        # Set preview state
        self.preview_source_id = source_id
        self.preview_position = start_position
        self.preview_playing = True

        # Set source playback position
        source = self.sources[source_id]
        source.set_position(start_position)

        # The stream plays at one fixed rate, so a running stream can only be
        # reused if it already runs at this source's native rate (otherwise the
        # preview plays at the wrong speed/pitch). Any other stream, including
        # a stopped one, is closed before a new one is opened.
        stream = self.preview_stream
        if stream is None or not stream.active or stream.samplerate != source.sample_rate:
            self._close_preview_stream()
            try:
                self.preview_stream = sd.OutputStream(
                    samplerate=source.sample_rate,
                    channels=2,
                    callback=self._preview_audio_callback,
                    blocksize=1024,
                    dtype='float32'
                )
                self.preview_stream.start()
                logger.info(f"Started preview audio stream: {source_id} at {start_position:.2f}s")
            except Exception as e:
                logger.error(f"Failed to start preview audio stream: {e}", exc_info=True)
                # Release a stream that was opened but failed to start.
                self._close_preview_stream()
                self.preview_playing = False
                return False

        logger.debug(f"Started preview: {source_id} at {start_position:.2f}s")
        return True

    def pause_source_preview(self):
        """Pause source preview (keeps stream open, stops audio generation)."""
        self.preview_playing = False
        logger.debug("Paused source preview")

    def stop_source_preview(self):
        """Stop source preview, close the output stream and rewind the source."""
        # Silence the callback before tearing the stream down.
        self.preview_playing = False
        self._close_preview_stream()

        # Reset preview state
        if self.preview_source_id and self.preview_source_id in self.sources:
            source = self.sources[self.preview_source_id]
            source.set_position(0.0)

        self.preview_source_id = None
        self.preview_position = 0.0
        logger.debug("Stopped source preview")

    def _close_preview_stream(self):
        """Stop and close the preview output stream if one is open (errors are logged)."""
        stream = self.preview_stream
        if stream is None:
            return
        # Drop the reference first so a failing close never leaves a
        # half-closed stream behind for start_source_preview() to reuse.
        self.preview_stream = None
        try:
            try:
                stream.stop()
            finally:
                # Close even if stop() raised, so the device is released.
                stream.close()
            logger.debug("Closed preview audio stream")
        except Exception as e:
            logger.error(f"Error closing preview stream: {e}")

    def seek_source_preview(self, position: float):
        """Seek to position (seconds) in source preview; no-op without a preview source."""
        if self.preview_source_id and self.preview_source_id in self.sources:
            self.preview_position = position
            source = self.sources[self.preview_source_id]
            source.set_position(position)
            logger.debug(f"Seek preview to {position:.2f}s")

    def _preview_audio_callback(self, outdata, frames, time_info, status):
        """Sounddevice callback for preview audio output.

        Any exception while rendering is caught here instead of propagating
        into the stream. The block is filled with silence and the preview is
        paused, so later callbacks output silence.
        """
        if status:
            logger.warning(f"Preview audio callback status: {status}")

        try:
            # Get audio from the preview sample method and copy to the output buffer
            outdata[:] = self.get_source_preview_sample(frames)
        except Exception as e:
            logger.error(f"Preview audio callback failed: {e}", exc_info=True)
            outdata.fill(0)
            self.preview_playing = False

    @staticmethod
    def _to_stereo(audio: np.ndarray) -> np.ndarray:
        """Map a block of frames to two output columns for rendering only.

        1-D arrays and single-column arrays are duplicated to left/right.
        Arrays with two or more columns keep their first two columns.
        """
        if audio.ndim == 1:
            return np.column_stack([audio, audio])
        if audio.shape[1] == 1:
            return np.concatenate([audio, audio], axis=1)
        return audio[:, :2]

    def get_source_preview_sample(self, frame_count: int = 512) -> np.ndarray:
        """Get audio sample for source preview (single source only).

        Returns a (frame_count, 2) float32 block of the previewed source scaled
        by its volume, and advances its playback position. Mute state and
        spatial position are not applied. Silence is returned when nothing is
        playing or the source has ended; a final partial block is zero-padded.
        """
        output = np.zeros((frame_count, 2), dtype=np.float32)
        if (not self.preview_playing or
                not self.preview_source_id or
                self.preview_source_id not in self.sources):
            return output

        source = self.sources[self.preview_source_id]

        # Get source audio frames (covers the "already at the end" case too)
        start = source.current_position
        end_pos = min(start + frame_count, len(source.audio_data))
        source_frames = end_pos - start
        if source_frames <= 0:
            return output

        # Mono / multi-channel material is mapped to stereo for output only.
        stereo_audio = self._to_stereo(source.audio_data[start:end_pos])

        # Update position
        source.current_position = end_pos
        self.preview_position = end_pos / source.sample_rate

        # Apply volume; frames past the end of the source stay silent.
        output[:source_frames] = stereo_audio * source.volume
        return output

    def is_source_playing(self, source_id: str) -> bool:
        """Return the source's is_playing flag (independent of mute; False if unknown)."""
        if source_id in self.sources:
            return self.sources[source_id].is_playing
        return False

    def get_spatial_audio_sample(self, frame_count: int = 512) -> np.ndarray:
        """
        Get spatial audio preserving original quality.
        Only applies minimal spatial positioning math - no audio molding.

        Mixes every source that is playing, not muted and not finished into a
        (frame_count, 2) float32 block, and advances each mixed source by up to
        frame_count frames. No resampling is done: every source is read
        frame-for-frame, whatever its own sample rate.
        """
        # Silence when there are no sources or no session format yet
        if not self.sources or self.session_sample_rate is None:
            return np.zeros((frame_count, 2), dtype=np.float32)

        # Mix all sources with spatial positioning only
        output = np.zeros((frame_count, 2), dtype=np.float64)

        for source in self.sources.values():
            if not source.is_playing or source.is_muted:
                continue

            start = source.current_position
            end_pos = min(start + frame_count, len(source.audio_data))
            source_frames = end_pos - start
            if source_frames <= 0:
                continue

            # Extract exactly as recorded, then map to stereo for the mix
            # (1-D mono, single-column mono and >2 channels are all accepted).
            source_audio = self._to_stereo(source.audio_data[start:end_pos])

            # Apply only minimal spatial positioning, then simple volume (no limiting)
            positioned_audio = self._apply_spatial_positioning(source_audio, source.position)
            positioned_audio *= source.volume

            # Mix into output and update playback position
            output[:source_frames] += positioned_audio
            source.current_position = end_pos

        return output.astype(np.float32)

    def _apply_spatial_positioning(self, audio: np.ndarray, position: np.ndarray) -> np.ndarray:
        """
        Apply MINIMAL spatial positioning - just basic panning math.
        No processing, no molding, no limiting - just position in space.

        Args:
            audio: Block with at least two columns; columns 0 (left) and 1
                (right) are scaled. The input array is not modified.
            position: (x, y, z). x drives the left/right balance; the Euclidean
                distance drives a 1 / (1 + 0.1 * distance) rolloff.

        Returns:
            A new float64 array with the gains applied.
        """
        x, y, z = position

        # Simple distance attenuation (no complex processing)
        distance = np.sqrt(x * x + y * y + z * z)
        distance_factor = 1.0 / (1.0 + distance * 0.1)  # Simple rolloff

        # Balance-style panning on a smooth tanh curve. The channel on the
        # source's side stays at unity and the opposite channel fades as
        # sqrt(1 - |pan|). Both gains are exactly 1.0 at the centre, so the
        # gain changes continuously as a source moves across x == 0.
        pan = np.tanh(x * 0.5)
        left_gain = np.sqrt(1.0 - pan) if pan > 0 else 1.0
        right_gain = np.sqrt(1.0 + pan) if pan < 0 else 1.0

        # astype() always copies. Casting to float64 also stops integer PCM
        # input from failing the in-place float multiplies below.
        positioned = audio.astype(np.float64)
        positioned[:, 0] *= left_gain * distance_factor    # Left channel
        positioned[:, 1] *= right_gain * distance_factor   # Right channel

        return positioned

    def set_position_callback(self, callback: Callable):
        """Set position update callback for transport sync."""
        self.position_callback = callback

    def get_import_quality_report(self) -> List[Dict]:
        """Get report of how well we preserved original audio quality (shallow copy)."""
        return self.import_stats.copy()

    def play(self):
        """Start audio playback - backward compatibility."""
        for source in self.sources.values():
            source.is_playing = True
        logger.info("Audio playing - quality preserved")

    def pause(self):
        """Pause audio playback - backward compatibility."""
        for source in self.sources.values():
            source.is_playing = False
        logger.info("Audio paused - quality preserved")

    def stop(self):
        """Stop audio playback - backward compatibility."""
        # Reset all source positions
        for source in self.sources.values():
            source.reset_playback()
            source.is_playing = False
        logger.info("Audio stopped - quality preserved")

    def get_duration(self) -> float:
        """Get total duration of all sources - backward compatibility."""
        if not self.sources:
            return 0.0
        # Return the longest source duration
        return max(source.duration for source in self.sources.values())

    def cleanup(self):
        """Clean shutdown: stop any preview and release its output stream."""
        self.stop_source_preview()
        logger.info("Transparent Audio Engine cleanup complete")


class AudioSource:
    """
    Audio source that preserves original quality completely.
    No processing, no conversion - exactly as recorded.

    Holds the loaded sample array untouched, together with its format details,
    a 3D position, a linear volume, mute/playing flags and the playback cursor
    ``current_position`` (in frames).
    """

    def __init__(self, source_id: str, audio_data: np.ndarray, sample_rate: int,
                 channels: int, bit_depth: str, position: np.ndarray, volume: float = 1.0,
                 format_info: Optional[Dict] = None):
        """
        Args:
            source_id: Identifier of the source.
            audio_data: Samples exactly as loaded; ``len(audio_data)`` is used
                as the frame count.
            sample_rate: Native sample rate in Hz.
            channels: Channel count reported by the loader (stored as given).
            bit_depth: Bit depth / subtype reported by the loader.
            position: (x, y, z) position.
            volume: Linear gain; 1.0 is unity.
            format_info: Full loader metadata; defaults to an empty dict.
        """
        self.source_id = source_id
        self.audio_data = audio_data  # Exactly as loaded
        self.sample_rate = sample_rate  # Original rate
        self.channels = channels  # Original channels
        self.bit_depth = bit_depth  # Original bit depth
        self.position = position
        self.volume = volume
        self.format_info = format_info or {}

        # Playback state
        self.current_position = 0  # frame index of the next frame to play
        self.is_playing = False
        self.is_muted = False  # Separate mute state

        # Derived properties
        self.duration = len(audio_data) / sample_rate  # seconds
        self.frames = len(audio_data)

    def reset_playback(self):
        """Reset to beginning - no processing."""
        self.current_position = 0

    def set_position(self, seconds: float):
        """Seek to time position - sample accurate.

        The time is converted to the nearest frame and clamped to
        [0, len(audio_data)]. Rounding instead of truncating matters because
        float error would otherwise land one frame early (e.g. 0.29 s at
        100 Hz evaluates to 28.999... and would truncate to frame 28).
        """
        frame = int(round(seconds * self.sample_rate))
        self.current_position = max(0, min(frame, len(self.audio_data)))


# Backward-compatible alias: ``AudioEngine`` names the TransparentAudioEngine
# class itself (no instance is created here).
AudioEngine = TransparentAudioEngine