"""
Embedded Python Blocks:

Each time this file is saved, GRC will instantiate the first class it finds
to get ports and parameters of your block. The arguments to __init__ will
be the parameters. All of them are required to have default values!
"""

import numpy as np
from gnuradio import gr


class blk(gr.basic_block):  # other base classes are basic_block, decim_block, interp_block
    """Detect a DSSS preamble/PHR in a spread stream and align the output to DSSS bytes.

    The input is a stream of spread bits, each ``spreading_factor`` items long
    (``self.chips`` holds the zero chip and its complement).

    The block relies on stream tags attached upstream. Per the original comments,
    each tag marks a run of ``dsss_access_code_length`` items of zero chips.
    This is an assumption about the upstream tagger -- TODO confirm.

    A run of >= 25 tags spaced exactly one spread bit apart is treated as the
    preamble. After alignment, data is passed through in chunks of whole spread
    bits. Zero chips are inserted where needed so that output stays aligned to
    DSSS bytes (``8 * spreading_factor`` items).

    This is a general block: the number of consumed input items (reported via
    ``consume_each``) and the number of produced output items (the return value
    of ``general_work``) are independent.
    """

    def __init__(self, spreading_factor=15, mtu=127, dsss_access_code_length=4*15, dsss_phr_length=6*8*15):  # only default arguments here
        """Arguments to this function show up as parameters in GRC.

        Args:
            spreading_factor: number of input items per spread bit. Must equal
                the length of each row of ``self.chips`` (15).
            mtu: maximum payload length in DSSS bytes. Sync is dropped once
                ``mtu * 8 * spreading_factor`` items have been read since the
                last sync point.
            dsss_access_code_length: number of input items covered by one
                access-code tag (default: 4 spread bits).
            dsss_phr_length: number of input items of the spread PHR
                (default: 6 bytes * 8 bits * 15 items).
        """
        # Input consumption and output production are decoupled, so the block
        # must be initialised as the basic_block it derives from. A sync_block
        # would additionally auto-consume every returned item count, on top of
        # the explicit consume_each() calls below.
        gr.basic_block.__init__(
            self,
            name='Detect preamble - Align output',   # will show up in GRC
            in_sig=[np.byte],
            out_sig=[np.byte]
        )
        # if an attribute with the same name as a parameter is found,
        # a callback is registered (properties work, too).
        # NOTE(review): such a callback presumably only reassigns the parameter
        # attribute itself. The derived values computed below are not refreshed.
        self.spreading_factor = spreading_factor
        self.mtu = mtu
        self.dsss_access_code_length = dsss_access_code_length
        # access-code length expressed in spread bits
        self.dsss_access_code_chip_length = int(self.dsss_access_code_length / self.spreading_factor)
        self.dsss_phr_length = dsss_phr_length
        # synchronisation state
        self.offset_synced = False
        self.last_synced_offset = 0
        self.tagRelativeOffset = 0
        self.previously_synced = False
        self.fresh_full_resync = False
        # one DSSS byte is 8 spread bits; some padding aligns to two DSSS bytes
        self.dsss_byte = self.spreading_factor * 8
        self.dsss_byte_alignment_factor = self.dsss_byte * 2
        self.count = 0
        self.first_nonmatch_index = 0
        # spreading sequences: row 0 is the zero chip, row 1 its complement
        self.chips = np.array([[1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0],
                               [0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1]], dtype=np.uint8)
        self.last_written_chip = self.chips[0]
        # largest packet (PHR + MTU payload) in input items
        self.max_input_packet = self.dsss_phr_length + self.mtu * self.dsss_byte
        # stall detection: how often general_work waited at the same read position
        self.stalling_pos = 0
        self.stalling_counter = 0
        self.stalling_limit = 32

    def general_work(self, input_items, output_items):
        """Detect the incoming PHR and align the output to a DSSS-byte margin.

        Args:
            input_items: list with one input buffer (array of ``np.byte``).
            output_items: list with one output buffer (array of ``np.byte``).

        Returns:
            Number of output items produced. Input consumption is reported
            separately through ``consume_each``.
        """
        in_buff_length = len(input_items[0])
        # never examine more than the PHR plus 8 DSSS bytes in one call
        input_length = min(in_buff_length, self.dsss_phr_length + 8 * self.dsss_byte)
        # tags further away than two maximum-size packets are irrelevant
        tag_window_length = min(in_buff_length, self.max_input_packet << 1)

        tags = self.get_tags_in_window(0, 0, tag_window_length)
        tags_length = len(tags)

        if self.offset_synced and input_length >= self.dsss_byte:
            # drop sync once a full MTU has passed since the last sync point
            if (self.nitems_read(0) - self.last_synced_offset) >= (self.mtu * self.dsss_byte):
                self.offset_synced = False

            if self.fresh_full_resync:
                pre_read = self.nitems_read(0)
                relevant_tag_index = 0
                if tags_length > 1 and (tags[relevant_tag_index].offset - pre_read) == 0:
                    # we already read the four zero chips; check the subsequent tag instead
                    relevant_tag_index = 1
                if tags_length > 0 and (tags[relevant_tag_index].offset - pre_read) >= (self.dsss_byte + self.dsss_access_code_length):
                    # the next zero-chip run is at least a byte away: we are synced, pipe it all out
                    self.fresh_full_resync = False
                else:
                    # some zero chips lie inside the first byte: skip them without writing
                    self.offset_synced = False
                    to_consume = 0
                    if tags_length > 0:
                        # at least one tag, so at least one full access-code run: consume up to it
                        tag_rel = tags[relevant_tag_index].offset - pre_read
                        if tag_rel < 0:
                            # the run was partially read before
                            to_consume = self.dsss_access_code_length + tag_rel
                        else:
                            to_consume = tag_rel
                    else:
                        # no tag: fewer than a full run; consume the leading zero chips
                        sf = self.spreading_factor
                        i = 0
                        while (i < self.dsss_access_code_chip_length
                               and np.array_equal(input_items[0][i * sf:i * sf + sf], self.chips[0])):
                            to_consume += sf
                            i += 1
                    self.fresh_full_resync = False
                    self.consume_each(to_consume)  # consume consecutive 0 bits
                    return 0  # but don't write them

            return self.consume_and_write_spread(input_items, output_items, factor=8)

        elif self.offset_synced:
            # synced, but less than one DSSS byte available: wait for more input
            return 0

        elif input_length >= self.dsss_phr_length + self.spreading_factor:
            if tags_length >= 25:  # first tag offset covers 4 bits, in total 32 are needed => 29 remain
                return self._align_to_preamble(input_items, output_items, tags, input_length)

            elif self.previously_synced:
                if in_buff_length < 8 * self.spreading_factor:
                    return 0
                if tags_length < 1 or (tags[0].offset - self.nitems_read(0) - self.dsss_access_code_length) > self.dsss_byte:
                    # no zero-chip run within the next byte: pass data through
                    return self.consume_and_write_spread(input_items, output_items, factor=8)
                # From here on tags[0] is at most one DSSS byte away.
                tag_gap = tags[1].offset - tags[0].offset if tags_length >= 3 else None
                if tag_gap in (self.spreading_factor, self.spreading_factor << 1):
                    # a preamble may be starting: wait for more tags, unless we are stalling
                    if self.stalling_pos == self.nitems_read(0):
                        self.stalling_counter += 1
                    else:
                        self.stalling_pos = self.nitems_read(0)
                        self.stalling_counter = 0
                    if self.stalling_counter < self.stalling_limit:
                        return 0
                return self.consume_and_write_spread(input_items, output_items, factor=4)

            elif tags_length > 0:  # assumes not previously synced
                consuming = tags[0].offset - self.nitems_read(0) - self.dsss_access_code_length
                unsynced_written = 0
                if consuming < 0:
                    unsynced_written = -1 * consuming
                    consuming = tags[0].offset - self.nitems_read(0)
                    for i in range(unsynced_written):
                        output_items[0][i] = input_items[0][consuming + (i % self.spreading_factor)]
                    output_items[0][unsynced_written:unsynced_written + consuming] = input_items[0][:consuming]
                    unsynced_written += consuming
                self.previously_synced = True
                self.consume_each(consuming)
                return unsynced_written

            else:
                # no tags at all: drop input that cannot hold the start of a PHR
                self.consume_each(input_length - (self.dsss_phr_length + self.spreading_factor))
                return 0

        else:
            return 0

    # Backward-compatible alias for code that called the former method name.
    work = general_work

    def _align_to_preamble(self, input_items, output_items, tags, input_length):
        """Handle a window with >= 25 tags: locate the preamble and realign output.

        Args:
            input_items, output_items: the buffers passed to ``general_work``.
            tags: tags in the current window (as returned by ``get_tags_in_window``).
            input_length: number of input items ``general_work`` may process.

        Returns:
            Number of output items produced.
        """
        tags_length = len(tags)
        out = output_items[0]
        inp = input_items[0]
        sf = self.spreading_factor

        # index of the first tag not exactly one spread bit after its predecessor
        previous_tag_offset_abs = tags[0].offset
        self.first_nonmatch_index = 0
        for idx, tag in enumerate(tags[1:], start=1):
            if (tag.offset - previous_tag_offset_abs) != sf:
                self.first_nonmatch_index = idx
                break
            previous_tag_offset_abs = tag.offset

        if self.first_nonmatch_index == 0:
            # all tags are evenly spaced
            if tags_length > 42:
                return self.consume_and_write_spread(input_items, output_items, factor=8)
            elif tags_length >= 25:
                self.first_nonmatch_index = tags_length

        if self.first_nonmatch_index <= 24:  # allow for some corruption (default 28)
            reference_tag = tags[self.first_nonmatch_index]
        else:
            shift_for_missing = 0
            if self.first_nonmatch_index < 29:
                # compensate for missing overlap; 0 < shift_for_missing < 5
                shift_for_missing = 29 - self.first_nonmatch_index
            reference_tag = tags[self.first_nonmatch_index - (29 - shift_for_missing)]
        self.tagRelativeOffset = reference_tag.offset - self.nitems_read(0) - self.dsss_access_code_length

        if self.tagRelativeOffset > 0:
            # the preamble starts later in the buffer: flush the data before it
            to_consume = self.tagRelativeOffset

            if self.previously_synced:
                # chip phase of the output position right after the flushed data
                offset = (self.nitems_written(0) + to_consume) % sf
                if self.first_nonmatch_index >= 25:
                    out[:to_consume] = inp[:to_consume]
                    # repeat the last written chip once
                    if to_consume >= sf:
                        last_chip_len = sf
                        out[to_consume:to_consume + last_chip_len] = inp[to_consume - last_chip_len:to_consume]
                    else:
                        last_chip_len = len(self.last_written_chip)
                        out[to_consume:to_consume + last_chip_len] = self.last_written_chip
                    # make sure to start a new chip (of a potentially new packet) at a byte aligned offset
                    align = self.dsss_byte_alignment_factor
                    to_pad = (align - ((self.nitems_written(0) + to_consume + last_chip_len) % align)) % align
                    to_pad += 1 * self.dsss_byte
                    self._write_zero_chips(out, to_consume + last_chip_len, to_pad, offset)
                    total_written = to_consume + last_chip_len + to_pad
                    self.last_written_chip = self.chips[0]
                else:
                    # Write the data first, then complete the partial chip it
                    # ends in. The padding phase (offset + i) matches exactly this
                    # position, as in the branch above.
                    out[:to_consume] = inp[:to_consume]
                    to_pad = (sf - offset) % sf
                    self._write_zero_chips(out, to_consume, to_pad, offset)
                    total_written = to_consume + to_pad
            else:
                total_written = 0

            self.consume_each(to_consume)
            return total_written

        # the preamble starts at (or before) the current read position
        if self.first_nonmatch_index >= 25:
            self.offset_synced = True
            self.last_synced_offset = self.tagRelativeOffset + self.nitems_read(0)
            self.previously_synced = True
            self.fresh_full_resync = False
        total_written = 0
        current_position_in_byte = self.nitems_written(0) % self.dsss_byte

        # adjust input length (to be written out) to be byte-aligned
        input_length -= (input_length + current_position_in_byte) % self.dsss_byte

        if self.first_nonmatch_index >= 25 and tags_length == self.first_nonmatch_index:
            self.fresh_full_resync = True
            pre_written_chips_balance = -self.tagRelativeOffset - current_position_in_byte

            to_pad = 0
            if pre_written_chips_balance >= 0:
                to_pad += self.dsss_byte
                to_pad += pre_written_chips_balance
            else:
                align = self.dsss_byte_alignment_factor
                to_pad += (align - (current_position_in_byte % align)) % align
                to_pad += -self.tagRelativeOffset
            to_pad += 1 * self.dsss_byte
            # leading zero chips, phased to the current output position
            self._write_zero_chips(out, 0, to_pad, current_position_in_byte % sf)
            total_written += to_pad
            # pass input up to the last tag, rounded up to a whole DSSS byte
            last_tag_offset = tags[self.first_nonmatch_index - 1].offset - self.nitems_read(0)
            input_length = last_tag_offset + (self.dsss_byte - last_tag_offset) % self.dsss_byte

        out[total_written:total_written + input_length] = inp[:input_length]
        total_written += input_length
        self.consume_each(input_length)
        return total_written

    def _write_zero_chips(self, out, start, count, phase):
        """Write ``count`` items of the repeating zero chip into ``out`` from ``start``.

        Item ``k`` written is ``self.chips[0][(phase + k) % spreading_factor]``.
        ``phase`` therefore selects where inside the chip the padding begins.
        A non-positive ``count`` writes nothing.
        """
        if count <= 0:
            return
        chip_index = (phase + np.arange(count)) % self.spreading_factor
        out[start:start + count] = self.chips[0][chip_index]  # write (partial) 0-chip(s)

    def consume_and_write_spread(self, input_items, output_items, factor=8):
        """Pass ``factor`` spread bits straight through from input to output.

        Copies and consumes ``spreading_factor * factor`` items and returns that
        count as the number of items produced.
        """
        n_items = self.spreading_factor * factor
        output_items[0][:n_items] = input_items[0][:n_items]
        self.consume_each(n_items)
        return n_items
    

