from __future__ import absolute_import, print_function, unicode_literals

from ableton.v2.control_surface.input_control_element import *
import socket
import select
import struct
import errno
import sys
import threading
import time


class AbletonSocketServer:
    def __init__(self, parent, host='0.0.0.0', port=6000):
        """Set up server state; no socket is opened until start_server().

        Args:
            parent: Owning object. Other methods of this class use its
                ``_log``, ``_tasks``, ``song``, ``set_man``, ``map_comp`` and
                ``socket_manager`` attributes.
            host: Interface address the server socket binds to.
            port: TCP port the server socket binds to.
        """
        self.parent_instance = parent
        self.host = host
        self.port = port
        self.server = None
        # Maps each connected client socket to its bookkeeping dict.
        self.client_sockets = {}
        # Validated clients that have not yet been sent the setlist index.
        self.new_clients = []
        # Client skipped by the next song-pool broadcast. _poll_socket reads
        # this and resets it to None after every poll, so it must exist
        # before the first poll runs.
        self.excluded_client = None
        self.running = False
        self.port_check_counter = 0
        # One-shot flags consumed (and cleared) by _poll_socket.
        self.song_pool_refreshed = False
        self.mapping_needs_refresh = False
        self.cue_jump_mode_needs_refresh = False
        self.bpm_needs_init = False
        self.server_id = u"CUE2LIVE_SERVER"
        # Always encode: under unicode_literals on Python 2 the identifier is
        # `unicode`, which an isinstance(..., str) check would leave unencoded.
        self.encoded_id = self.server_id.encode("utf-8")
        # Command/control identifiers received from clients -> handler name.
        self.commands = {
            1: "play",
            2: "stop",
            3: "previous",
            4: "next"
        }
        self.current_tempo = 120.0
        
    def start_server(self):
        """Start the server and schedule periodic checks."""
        self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.server.setblocking(False)  # Non-blocking mode
        self.server.bind((self.host, self.port))
        self.running = True

            
    
    def _drop_client(self, client):
        """Forget *client* and close its socket; safe to call more than once."""
        self.client_sockets.pop(client, None)
        if client in self.new_clients:
            self.new_clients.remove(client)
        try:
            client.close()
        except (socket.error, IOError) as e:
            self.parent_instance._log("Error closing client socket: {}".format(e))

    def _poll_socket(self, timer=None):
        """Run one polling cycle and schedule the next one 0.1 s later.

        A cycle (1) checks every tenth call whether the configured port
        changed, (2) accepts at most one pending connection and greets it with
        the server id, (3) reads from every known client (handshake first,
        then regular messages), (4) pushes pending updates plus a heartbeat to
        every client and (5) clears the one-shot refresh flags.

        Args:
            timer: Unused; accepted so the method can be used as a task
                callback.
        """
        if not self.running:
            return

        log = self.parent_instance._log

        self.port_check_counter += 1
        if self.port_check_counter >= 10:
            self.port_check_counter = 0
            new_port = self.parent_instance.socket_manager.load_port()
            if new_port != self.port:
                # Hand task.run a callable. Calling handle_port_change() here
                # would pass its None result to task.run instead.
                # handle_port_change() restarts polling itself, hence the
                # return without rescheduling.
                self.parent_instance._tasks.add(
                    task.run(lambda *_: self.handle_port_change(new_port)))
                return

        # Accept at most one new client per cycle.
        client = None
        try:
            client, address = self.server.accept()
        except (socket.error, IOError):
            client = None  # No new connections
        if client is not None:
            try:
                self.send_payload(client, self.encoded_id)
                client.setblocking(False)
            except (socket.error, IOError) as e:
                # Do not leak the accepted socket when the greeting fails.
                log("Failed to greet new client: {}".format(e))
                client.close()
            else:
                self.client_sockets[client] = {
                    "ip": address[0],
                    "validated": False,
                    "handshake_attempts": 0  # Optional: to avoid infinite retries
                }
                # A new client needs the complete state on this cycle.
                self.song_pool_refreshed = True
                self.mapping_needs_refresh = True
                self.cue_jump_mode_needs_refresh = True
                self.bpm_needs_init = True

        # Process data from existing clients. Peek first so the byte stays
        # queued for the handshake / message reader.
        peek_flags = socket.MSG_PEEK if hasattr(socket, "MSG_PEEK") else 0
        for client in list(self.client_sockets):  # snapshot: clients may be dropped
            try:
                data = client.recv(1, peek_flags)

                if not data:
                    log("Client disconnected.")
                    self._drop_client(client)
                elif not self.client_sockets[client].get("validated", False):
                    # Attempt to validate handshake
                    try:
                        handshake_data = client.recv(64)
                    except (socket.error, IOError) as e:
                        log("Handshake error: {}".format(e))
                        continue  # Wait until the next poll to try again
                    if b"CUE2LIVE_CLIENT" in handshake_data:
                        self.client_sockets[client]["validated"] = True
                        log("Client handshake validated.")
                        self.new_clients.append(client)
                    else:
                        log("Invalid handshake from client.")
                        self._drop_client(client)
                else:
                    self._handle_client(client)

            except (socket.error, IOError) as e:
                code = getattr(e, 'errno', None)
                if code in (errno.EAGAIN, errno.EWOULDBLOCK):
                    continue  # Nothing to read right now
                if code == errno.ECONNRESET:
                    log("Client forcefully disconnected.")
                else:
                    log("Socket error: {}".format(e))
                self._drop_client(client)

            except Exception as e:
                log("Error handling client: {}".format(e))
                self._drop_client(client)

        # Send updates if needed. Iterate a snapshot because the senders may
        # remove a broken client from the registry.
        song = self.parent_instance.song
        heartbeat = bytearray(struct.pack("B", 2))
        for client in list(self.client_sockets):

            if self.song_pool_refreshed and client != getattr(self, "excluded_client", None):
                try:
                    # Song pool section followed by setlist section
                    payload = bytearray()
                    self._build_song_pool_payload(payload)
                    self._build_setlist_payload(payload)
                    self.send_payload(client, payload)
                except Exception as e:
                    log("Failed to send update: {}".format(e))

            if self.mapping_needs_refresh:
                self.send_midi_map_init(client)

            if self.cue_jump_mode_needs_refresh:
                self.send_cue_jump_mode(client)

            if song.is_playing and (song.tempo != self.current_tempo or self.bpm_needs_init):
                self.send_tempo(client)

            if client in self.new_clients:
                self.send_setlist_index_singular(client, self.parent_instance.set_man.setlist_index)
                self.new_clients.remove(client)

            # Heartbeat
            try:
                self.send_payload(client, heartbeat)
            except Exception as e:
                log("Failed to send heartbeat: {}".format(e))

        # Reset the one-shot flags now that every client has been served.
        self.song_pool_refreshed = False
        self.mapping_needs_refresh = False
        self.cue_jump_mode_needs_refresh = False
        self.bpm_needs_init = False
        self.current_tempo = self.parent_instance.song.tempo
        self.excluded_client = None
        if self.parent_instance.loop_update:
            self.parent_instance.update_loop()

        # Schedule the next check
        self.parent_instance._tasks.add(
            task.sequence(task.DelayTask(0.1), task.run(self._poll_socket)))



    def _build_song_pool_payload(self, payload):
        """Builds the song pool update section of the payload."""
        song_pool_header = struct.pack("B", 3)  # Header for Song Pool Update
        payload.extend(song_pool_header)

        for song_id, song_data in self.parent_instance.set_man.song_pool.items():
            name = song_data["name"].encode("utf-8")
            name_length = len(name)
            
            # Assume length is stored as a string and encode it (e.g., "03:45")
            length = song_data["length"].encode("utf-8")
            length_length = len(length)  # Length of the length string

            # Pack the song ID (UInt64), name length (UInt8), name, length length (UInt8), and length
            payload.extend(struct.pack(">Q", song_id))  # Song ID as UInt64
            payload.extend(struct.pack("B", name_length))  # Name length as UInt8
            payload.extend(name)  # Actual name bytes
            payload.extend(struct.pack("B", length_length))  # Length string length as UInt8
            payload.extend(length)  # Actual length bytes



    def _build_setlist_payload(self, payload):
        """Builds the setlist update section of the payload."""
        setlist_header = struct.pack("B", 4)  # Header for Setlist Update
        payload.extend(setlist_header)

        if self.parent_instance.set_man.setlist is not None:
            for song in self.parent_instance.set_man.setlist:
                song_id = song["id"]
                song_name = song["name"]
                song_length = song["length"]  # Assume the length is stored as a string like "03:45"

                # Encode the name and length
                name = song_name.encode("utf-8")
                name_length = len(name)

                length = song_length.encode("utf-8")
                length_length = len(length)

                # Pack the song ID (UInt64), name length (UInt8), name, length length (UInt8), and length
                payload.extend(struct.pack(">Q", song_id))  # Song ID as UInt64
                payload.extend(struct.pack("B", name_length))  # Name length as UInt8
                payload.extend(name)  # Actual name bytes
                payload.extend(struct.pack("B", length_length))  # Length string length as UInt8
                payload.extend(length)  # Actual length bytes
        else:
            return


    def send_payload(self, client, payload):
        payload_length = len(payload)
        header = struct.pack(">I", payload_length)  # Big-endian unsigned int
        client.sendall(header + payload)  # Send length prefix followed by payload

    def send_midi_map_init(self, client):
        try:
            # Create the header for the MIDI map initialization
            header = struct.pack("B", 0)  # Header value 0 for MIDI map init
            payload = bytearray(header)

            # Prepare a log message to show all mappings being sent
            log_message = "Sending MIDI Map Initialization:\n"

            # Iterate through control mappings and append their data to the payload
            for control_name, mapping in self.parent_instance.map_comp.control_mappings.items():
                if mapping is not None:
                    msg_type = mapping["msg_type"]  # MIDI message type (e.g., Note On/Off)
                    channel = mapping["channel"]   # MIDI channel
                    identifier = mapping["identifier"]  # MIDI control number
                    # Pack the control name length and name
                    name_bytes = control_name.encode("utf-8")
                    name_length = len(name_bytes)
                    payload.extend(struct.pack("B", name_length))  # Control name length
                    payload.extend(name_bytes)                    # Control name bytes

                    # Pack the MIDI mapping data
                    payload.extend(struct.pack(">BBH", msg_type, channel, identifier))

                    # Append to the log message
                    log_message += "Control: {}, msgType: {}, Channel: {}, Identifier: {}\n".format(
                        control_name, msg_type, channel, identifier
                    )
                else:
                    pass

            # Log the payload content
            self.parent_instance._log(log_message)

            # Send the payload
            try:
                self.send_payload(client, payload)
            except socket.error as e:
                self.parent_instance._log("Failed to send payload: {}".format(e))
                if e.errno == errno.EPIPE:  # Broken pipe
                    self.parent_instance._log("Removing broken client.")
                    self.client_sockets.remove(client)
                    client.close()
        except Exception as e:
            self.parent_instance._log("Failed to send MIDI map initialization: {}".format(e))

    def send_setlist_index_singular(self, client, index):
        try:
            header = struct.pack("B", 5)
            payload = bytearray(header)
            payload.extend(struct.pack("B", index))
            self.send_payload(client, payload)
        except Exception as e:
            self.parent_instance._log("Failed to send setlist index update: {}".format(e))
    
    def send_setlist_index_update(self, index):
        for client in self.client_sockets:
            try:
                header = struct.pack("B", 5)
                payload = bytearray(header)
                payload.extend(struct.pack("B", index))
                self.send_payload(client, payload)
            except Exception as e:
                self.parent_instance._log("Failed to send setlist index update: {}".format(e))
        
    def send_cue_jump_mode(self, client):
        try:
            header = struct.pack("B", 7)
            payload = bytearray(header)
            payload.extend(struct.pack("B", self.parent_instance.current_cue_mode))
            self.send_payload(client, payload)
        except Exception as e:
            self.parent_instance._log("Failed to send cue jump mode update: {}".format(e))
            
    def send_tempo(self, client):
        try:
            header = struct.pack("B", 8)
            payload = bytearray(header)
            processed_bpm = int((round(self.parent_instance.song.tempo, 2) * 100))
            payload.extend(struct.pack(">I", processed_bpm))
            self.send_payload(client, payload)
        except Exception as e:
            self.parent_instance._log("Failed to send tempo update: {}".format(e))
            
    def send_is_playing(self, playing):
        for client in self.client_sockets:
            try:
                header = struct.pack("B", 1)
                payload = bytearray(header)
                if playing:
                    bool = struct.pack("B", 1)
                else:
                    bool = struct.pack("B", 2)
                payload.extend(bool)
                self.send_payload(client, payload)
            except Exception as e:
                self.parent_instance._log("Failed to send is playing update: {}".format(e))
                    
    def _handle_client(self, client):
        try:
            header = client.recv(1)
            if not header:
                self.parent_instance._log("No data received from client.")
                return  # Return without closing the connection


            msg_type = struct.unpack("B", header)[0]  # 0 = MIDI, 1 = Command, 2 = Remove Mapping, 3 = Index Update, 4 = Setlist Update
            self.parent_instance._log("Received header: {}".format(msg_type))


            if msg_type == 0:  # MIDI Message
                midi_data = client.recv(5)
                if len(midi_data) != 5:
                    self.parent_instance._log("Invalid MIDI message length: {}".format(len(midi_data)))
                    return  # Invalid data, but keep connection open


                channel, msg_type, value, control = struct.unpack(">BBHB", midi_data)
                self.handle_midi(control, msg_type, channel, value)


            elif msg_type == 1:  # Command
                command_data = client.recv(1)
                if not command_data:
                    self.parent_instance._log("No command data received.")
                    return  # Keep connection open


                command_id = struct.unpack("B", command_data)[0]
                self.handle_command(command_id)


            elif msg_type == 2:  # Remove Mapping
                remove_data = client.recv(1)
                if not remove_data:
                    self.parent_instance._log("No remove data received.")
                    return  # Keep connection open
                control = struct.unpack("B", remove_data)[0]
                self.remove_mapping(control)


            elif msg_type == 3:  # Setlist index update
                index_data = client.recv(1)
#                if isinstance(index_data, str):
#                        data = bytearray(data)
                if len(index_data) != 1:
                    raise ValueError("Invalid index data length")
                current_index = struct.unpack("B", index_data)[0]
                self.parent_instance.set_man.setlist_index = current_index
                self.parent_instance._log("Setlist index updated to {}".format(current_index))


            elif msg_type == 4:  # Setlist update
                index_update = None
                try:
                    data = client.recv(4096)  # Adjust the size as needed to fit your payload
                except Exception as e:
                    self.parent_instance._log("Setlist cleared.")
                    self.parent_instance.set_man.update_setlist([], client)
                    return
                try:
                    offset = 0
                    songs = []
                    
                    if isinstance(data, str):  # In Python 2, recv() returns str
                        data = bytearray(data)  # Convert to bytearray for compatibility

                    # Proceed with parsing setlist data
                    while offset < len(data):
                        # Check if there's enough data for song ID (8 bytes)
                        if offset + struct.calcsize(">Q") > len(data):
                            # Check if the remaining data is an index update instead of a song
                            if len(data[offset:]) == 2 and data[offset] == 3:
                                index_update = struct.unpack_from("B", data, offset + 1)[0]
                                break
                            elif len(data[offset:]) == 4 and data[offset] == 3 and data[offset + 2] == 3:
                                index_update = struct.unpack_from("B", data, offset + 3)[0]
                                break
                            else:
                                self.parent_instance._log("{}".format(data))
                                raise ValueError("Insufficient data for song ID at offset {}".format(offset))


                        # Read the song ID (UInt64, big-endian)
                        song_id = struct.unpack_from(">Q", data, offset)[0]
                        offset += struct.calcsize(">Q")
                        self.parent_instance._log("Offset after song ID: {}".format(offset))


                        # Check if there's enough data for name length (1 byte)
                        if offset + struct.calcsize("B") > len(data):
                            raise ValueError("Insufficient data for name length at offset {}".format(offset))


                        # Read the name length (UInt8)
                        name_length = struct.unpack_from("B", data, offset)[0]
                        offset += struct.calcsize("B")


                        # Check if there's enough data for the name (name_length bytes)
                        if offset + name_length > len(data):
                            raise ValueError("Insufficient data for name at offset {}".format(offset))
                            


                        # Read the name (UTF-8 string)
                        name = struct.unpack_from("{}s".format(name_length), data, offset)[0].decode("utf-8")
                        offset += name_length
                        

                        # Add the song to the list
                        songs.append({"id": song_id, "name": name, "length": self.parent_instance.set_man.song_pool[song_id]["length"]})
                        self.parent_instance._log("Parsed Song: ID={}, Name={}".format(song_id, name))


                    # Update the setlist
                    self.parent_instance.set_man.update_setlist(songs, client)
                    if index_update is not None:
                        self.parent_instance.set_man.setlist_index = index_update
                    self.parent_instance._log("Setlist updated with {} songs.".format(len(songs)))


                except Exception as e:
                    self.parent_instance._log("Error unpacking setlist data: {}".format(e))
                    
            elif msg_type == 5:
                data = client.recv(1024)
                offset = 0
                while offset < len(data):
                    # Read the song ID (UInt64, big-endian)
                    song_id = struct.unpack_from(">Q", data, offset)[0]
                    self.handle_audible(song_id)
                    return
                    
            elif msg_type == 6:
                timestamp_bytes = client.recv(8)
                sent_time = struct.unpack(">Q", timestamp_bytes)[0]
                self.parent_instance._log("Received timestamp: {}".format(sent_time))
                payload = header + timestamp_bytes
                self.send_payload(client, payload)
                

                        
            elif msg_type == 7:
                data = client.recv(1024)
                offset = 0
                while offset < len(data):
                    # Read the cue mode selection number
                    modeNum = struct.unpack("B", data)[0]
                    self.handle_cue_mode_change(modeNum)
                    return
                    
            elif msg_type == 8:
                self.activate_lk()
                return
        
                
        except Exception as e:
            self.parent_instance._log("Error handling client: {}".format(e))

    def handle_midi(self, control, msg_type, channel, value):
        """Register a MIDI mapping for the control identified by *control*.

        Args:
            control: Key into ``self.commands`` naming the control to map.
            msg_type: 0 for a note mapping, 1 for a CC mapping; any other
                value raises ``KeyError``.
            channel: MIDI channel of the mapping.
            value: Forwarded to ``add_mapping`` as the mapping's identifier.
        """
        resolved_type = {0: MIDI_NOTE_TYPE, 1: MIDI_CC_TYPE}[msg_type]
        target_name = self.commands.get(control)
        owner = self.parent_instance

        owner._log("Handling MIDI: Control = {}, Channel = {}, Msg_Type = {}, Value = {}".format(target_name, channel, resolved_type, value))

        owner.map_comp.add_mapping(
            control_name=target_name,
            msg_type=resolved_type,
            channel=channel,
            identifier=value,
        )
        
    def remove_mapping(self, control):
        
        control_name = self.commands.get(control)
        self.parent_instance.remove_control(control_name)
    
    def handle_command(self, command_id):
        """
        Processes commands.
        """
        command = self.commands.get(command_id)
        if command:
            self.parent_instance._log("Handling Command: {}".format(command))
            handler = getattr(self.parent_instance, "handle_{}".format(command), None)
            if handler:
                handler(127)
            else:
                self.parent_instance._log("No handler for command: {}".format(command))
        else:
            self.parent_instance._log("Unknown command ID: {}".format(command_id))
            
    def handle_audible(self, song_id):
        for cue in self.parent_instance.song.cue_points:
            if hash(cue) == song_id:
                self.parent_instance.jump_to_and_play(cue)
                
    def handle_port_change(self, port):
        """Move the listening socket to *port* and restart polling.

        The replacement socket is bound and put into listening mode before
        the running server is stopped. If that fails, the failure is logged
        and the server keeps listening on its current port; ``self.port`` is
        left unchanged in that case, so the periodic port check will retry.

        Nothing happens when *port* equals the current port or no server
        socket exists.

        Args:
            port: New TCP port to listen on.
        """
        if port == self.port or not self.server:
            return

        replacement = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            replacement.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            replacement.setblocking(False)  # Non-blocking mode
            replacement.bind((self.host, port))
            replacement.listen(5)
        except (socket.error, IOError) as e:
            replacement.close()
            self.parent_instance._log(
                "Failed to switch to port {}: {}".format(port, e))
            # The poll that requested the change did not reschedule itself,
            # so resume polling on the current socket.
            self.parent_instance._tasks.add(task.run(self._poll_socket))
            return

        # Commit only once the new socket is known to work.
        self.stop_server()
        self.server = replacement
        self.port = port
        self.running = True
        self.parent_instance._tasks.add(task.run(self._poll_socket))
        
    def handle_cue_mode_change(self, modeNum):
        self.parent_instance.current_cue_mode = modeNum
        
    def activate_lk(self):
        self.parent_instance.lk_active = True
    

    def stop_server(self):
        self.running = False
        if self.server:
            try:
                self.running = False
                self.server.shutdown(socket.SHUT_RDWR)  # Cleanly shut down both directions
                self.server.close()  # Close the server socket
                self.server = None
                self.parent_instance._log("Socket server stopped.")
            except Exception as e:
                self.parent_instance._log("Error while stopping the server: {}".format(e))


    # Compatibility Adjustments for Dictionary Iteration
    def iteritems(d):
        if sys.version_info[0] == 2:
            return d.iteritems()
        return d.items()
