"""
Advanced Sound Engine for SHAC Studio
Created by Claude to support sophisticated spatial audio composition

This engine provides high-level audio manipulation tools that make
creating complex spatial compositions more intuitive.
"""

import numpy as np
from typing import List, Dict, Optional, Tuple, Union
import logging

logger = logging.getLogger(__name__)


class SoundEngine:
    """
    Advanced sound engine with high-level audio manipulation capabilities.
    
    This class provides intuitive methods for:
    - Track creation and management
    - Multi-track mixing
    - Note-to-frequency conversion
    - Audio normalization
    - Time-based track construction
    """
    
    def __init__(self, sample_rate: int = 48000):
        """
        Set up the engine for a given sample rate.

        Args:
            sample_rate: Samples per second used for every time-to-sample
                conversion in this engine
        """
        # Built once up front so note lookups are plain dictionary reads
        # (A4 = 440Hz is the tuning reference).
        note_table = self._build_note_frequency_map()

        self.sample_rate = sample_rate
        self.note_frequencies = note_table
        
    def _build_note_frequency_map(self) -> Dict[str, float]:
        """
        Build the note-name -> frequency lookup table.

        Covers octaves 0 through 8 (C0 up to B8) in equal temperament with
        A4 = 440Hz. Every sharp is also registered under its flat spelling
        (e.g. 'A#3' and 'Bb3' map to the same frequency).

        Returns:
            Dictionary mapping note names such as 'C4' to frequencies in Hz
        """
        pitch_classes = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
        reference_midi = 69  # MIDI number of A4
        table = {}

        for octave in range(0, 9):
            for semitone, name in enumerate(pitch_classes):
                # MIDI numbering puts C(-1) at 0, hence the +1 on the octave
                midi_num = (octave + 1) * 12 + semitone
                # Equal temperament: f = 440 * 2^((midi_num - 69) / 12)
                hertz = 440.0 * (2.0 ** ((midi_num - reference_midi) / 12.0))

                table[f"{name}{octave}"] = hertz

                # A sharp is enharmonic with the flat of the next pitch class
                has_upper_neighbour = semitone < len(pitch_classes) - 1
                if '#' in name and has_upper_neighbour:
                    flat_name = pitch_classes[semitone + 1] + 'b'
                    table[f"{flat_name}{octave}"] = hertz

        return table
    
    def note_to_frequency(self, note: str) -> float:
        """
        Convert a note name to frequency.

        The lookup is case-insensitive, in line with note_to_midi: if the
        name is not found as written, it is retried with surrounding
        whitespace removed, the note letter upper-cased and the accidental
        lower-cased (so 'a4', 'c#3', 'bb2' and 'BB2' all resolve).

        Args:
            note: Note name (e.g., 'C4', 'A#3', 'Bb2')

        Returns:
            Frequency in Hz. Unknown notes log a warning and return
            440.0 (A4).
        """
        frequency = self.note_frequencies.get(note)

        if frequency is None and isinstance(note, str):
            # Table keys look like 'Bb2': capital letter, lowercase 'b' for flats
            trimmed = note.strip()
            canonical = trimmed[:1].upper() + trimmed[1:].lower()
            frequency = self.note_frequencies.get(canonical)

        if frequency is None:
            logger.warning("Note '%s' not found, defaulting to A4 (440Hz)", note)
            return 440.0

        return frequency
    
    def note_to_midi(self, note: str) -> int:
        """
        Convert a note name to MIDI number.

        Parsing is case-insensitive and ignores surrounding whitespace. The
        octave is the trailing run of digits, optionally preceded by a minus
        sign, so 'C-1' and 'C10' are understood as well as 'C4'.

        Args:
            note: Note name (e.g., 'C4', 'A#3', 'Bb2')

        Returns:
            MIDI note number (0-127, where C4 = 60). Results outside the
            MIDI range are clamped. Names that cannot be parsed (unknown
            pitch or missing octave) log a warning and return 60 (C4).
        """
        # Note to semitone mapping (keys are upper-cased, so 'BB' is B-flat)
        note_map = {
            'C': 0, 'C#': 1, 'DB': 1, 'D': 2, 'D#': 3, 'EB': 3,
            'E': 4, 'F': 5, 'F#': 6, 'GB': 6, 'G': 7, 'G#': 8,
            'AB': 8, 'A': 9, 'A#': 10, 'BB': 10, 'B': 11
        }

        text = note.strip().upper()

        # Walk back over the trailing digits to find where the octave starts
        split = len(text)
        while split > 0 and text[split - 1] in "0123456789":
            split -= 1
        has_octave = split < len(text)

        # A '-' directly before the digits is the octave's sign (e.g. 'C-1')
        if has_octave and split > 0 and text[split - 1] == '-':
            split -= 1

        note_name = text[:split]
        octave_text = text[split:]

        if not has_octave or note_name not in note_map:
            logger.warning("Note '%s' not recognized, returning 60 (C4)", text)
            return 60

        # MIDI note = (octave + 1) * 12 + semitone, so C4 = 60
        midi_note = (int(octave_text) + 1) * 12 + note_map[note_name]
        return max(0, min(127, midi_note))  # Clamp to valid MIDI range
    
    def create_track(self, duration: float) -> np.ndarray:
        """
        Create an empty audio track of specified duration.
        
        Args:
            duration: Duration in seconds
            
        Returns:
            Empty numpy array of correct length
        """
        num_samples = int(duration * self.sample_rate)
        return np.zeros(num_samples, dtype=np.float32)
    
    def add_to_track(self, track: np.ndarray, sound: np.ndarray, 
                     start_time: float, fade_in: float = 0.0, fade_out: float = 0.0) -> None:
        """
        Add a sound to a track at a specific time.
        
        Args:
            track: Target track array
            sound: Sound to add
            start_time: Start time in seconds
            fade_in: Fade in duration in seconds
            fade_out: Fade out duration in seconds
        """
        start_sample = int(start_time * self.sample_rate)
        
        # Apply fades if requested
        if fade_in > 0 or fade_out > 0:
            sound = self.apply_fades(sound, fade_in, fade_out)
        
        # Calculate how many samples we can actually add
        available_samples = len(track) - start_sample
        sound_samples = len(sound)
        
        if available_samples <= 0:
            return  # Start time is beyond track end
        
        samples_to_add = min(sound_samples, available_samples)
        
        # Add the sound to the track
        track[start_sample:start_sample + samples_to_add] += sound[:samples_to_add]
    
    def apply_fades(self, sound: np.ndarray, fade_in: float = 0.0, 
                    fade_out: float = 0.0) -> np.ndarray:
        """
        Apply fade in/out to a sound.
        
        Args:
            sound: Input sound
            fade_in: Fade in duration in seconds
            fade_out: Fade out duration in seconds
            
        Returns:
            Sound with fades applied
        """
        output = sound.copy()
        
        if fade_in > 0:
            fade_in_samples = int(fade_in * self.sample_rate)
            fade_in_samples = min(fade_in_samples, len(sound) // 2)
            
            fade_in_curve = np.linspace(0, 1, fade_in_samples)
            output[:fade_in_samples] *= fade_in_curve
        
        if fade_out > 0:
            fade_out_samples = int(fade_out * self.sample_rate)
            fade_out_samples = min(fade_out_samples, len(sound) // 2)
            
            fade_out_curve = np.linspace(1, 0, fade_out_samples)
            output[-fade_out_samples:] *= fade_out_curve
        
        return output
    
    def mix_tracks(self, tracks: List[np.ndarray], 
                   levels: Optional[List[float]] = None) -> np.ndarray:
        """
        Mix multiple tracks together.
        
        Args:
            tracks: List of track arrays
            levels: Optional list of level adjustments (0-1) for each track
            
        Returns:
            Mixed track
        """
        if not tracks:
            raise ValueError("No tracks to mix")
        
        # Find the longest track
        max_length = max(len(track) for track in tracks)
        
        # Create output array
        mixed = np.zeros(max_length, dtype=np.float32)
        
        # Mix tracks
        for i, track in enumerate(tracks):
            level = 1.0 if levels is None else levels[i]
            mixed[:len(track)] += track * level
        
        return mixed
    
    def normalize(self, audio: np.ndarray, target_level: float = 0.95) -> np.ndarray:
        """
        Normalize audio to a target level.
        
        Args:
            audio: Input audio
            target_level: Target peak level (0-1)
            
        Returns:
            Normalized audio
        """
        if len(audio) == 0:
            return audio
        
        # Find current peak
        current_peak = np.max(np.abs(audio))
        
        if current_peak == 0:
            return audio
        
        # Calculate scaling factor
        scale = target_level / current_peak
        
        return audio * scale
    
    def apply_envelope(self, sound: np.ndarray, 
                      attack: float = 0.01, 
                      decay: float = 0.1, 
                      sustain: float = 0.7, 
                      release: float = 0.3,
                      sustain_time: Optional[float] = None) -> np.ndarray:
        """
        Apply ADSR envelope to a sound.
        
        Args:
            sound: Input sound
            attack: Attack time in seconds
            decay: Decay time in seconds
            sustain: Sustain level (0-1)
            release: Release time in seconds
            sustain_time: Optional sustain duration (if None, uses remaining time)
            
        Returns:
            Sound with envelope applied
        """
        total_samples = len(sound)
        sr = self.sample_rate
        
        # Calculate sample counts
        attack_samples = int(attack * sr)
        decay_samples = int(decay * sr)
        release_samples = int(release * sr)
        
        if sustain_time is not None:
            sustain_samples = int(sustain_time * sr)
        else:
            sustain_samples = total_samples - attack_samples - decay_samples - release_samples
            sustain_samples = max(0, sustain_samples)
        
        # Create envelope
        envelope = np.ones(total_samples)
        current_sample = 0
        
        # Attack
        if attack_samples > 0 and current_sample < total_samples:
            end_sample = min(current_sample + attack_samples, total_samples)
            samples = end_sample - current_sample
            envelope[current_sample:end_sample] = np.linspace(0, 1, samples)
            current_sample = end_sample
        
        # Decay
        if decay_samples > 0 and current_sample < total_samples:
            end_sample = min(current_sample + decay_samples, total_samples)
            samples = end_sample - current_sample
            envelope[current_sample:end_sample] = np.linspace(1, sustain, samples)
            current_sample = end_sample
        
        # Sustain
        if sustain_samples > 0 and current_sample < total_samples:
            end_sample = min(current_sample + sustain_samples, total_samples)
            envelope[current_sample:end_sample] = sustain
            current_sample = end_sample
        
        # Release
        if release_samples > 0 and current_sample < total_samples:
            end_sample = min(current_sample + release_samples, total_samples)
            samples = end_sample - current_sample
            envelope[current_sample:end_sample] = np.linspace(sustain, 0, samples)
            current_sample = end_sample
        
        # Set any remaining samples to 0
        if current_sample < total_samples:
            envelope[current_sample:] = 0
        
        return sound * envelope
    
    def crossfade(self, sound1: np.ndarray, sound2: np.ndarray, 
                  overlap: float = 0.1) -> np.ndarray:
        """
        Crossfade between two sounds.
        
        Args:
            sound1: First sound
            sound2: Second sound
            overlap: Overlap duration in seconds
            
        Returns:
            Crossfaded result
        """
        overlap_samples = int(overlap * self.sample_rate)
        
        # If no overlap or sounds too short, just concatenate
        if overlap_samples <= 0 or len(sound1) < overlap_samples or len(sound2) < overlap_samples:
            return np.concatenate([sound1, sound2])
        
        # Calculate lengths
        fade_out_start = len(sound1) - overlap_samples
        total_length = len(sound1) + len(sound2) - overlap_samples
        
        # Create output
        output = np.zeros(total_length)
        
        # Copy non-overlapping parts
        output[:fade_out_start] = sound1[:fade_out_start]
        output[len(sound1):] = sound2[overlap_samples:]
        
        # Create crossfade
        fade_out = np.linspace(1, 0, overlap_samples)
        fade_in = np.linspace(0, 1, overlap_samples)
        
        output[fade_out_start:len(sound1)] = (
            sound1[fade_out_start:] * fade_out +
            sound2[:overlap_samples] * fade_in
        )
        
        return output