#
# This file is part of LiteSATA.
#
# Copyright (c) 2015-2019 Florent Kermarrec <florent@enjoy-digital.fr>
# SPDX-License-Identifier: BSD-2-Clause

from litesata.common import *


def test_type(name, signal):
    """Compare `signal` against the FIS type code registered under `name`.

    `name` is looked up in `fis_types` (a missing name raises whatever that lookup raises).
    The result is whatever `signal == code` evaluates to; in this file it is used as the
    condition of `If(...)` statements.
    """
    expected_code = fis_types[name]
    return signal == expected_code

# Transport TX -------------------------------------------------------------------------------------

class LiteSATATransportTX(Module):
    """Transmit half of the transport layer.

    Commands presented on `sink` are turned into 32-bit words on `link.sink`: first the header
    words built by the matching `*_header.encode()`, then, for "DATA" commands only, the payload
    words taken directly from `sink`.

    Only the "REG_H2D" and "DATA" types are handled. A valid `sink` beat of any other type is
    acknowledged in IDLE without anything being presented to the link.

    Parameters
    ----------
    link : object providing a `sink` endpoint with `valid`, `ready`, `last` and `data`.

    Attributes
    ----------
    sink : stream.Endpoint built from `transport_tx_description(32)`.
    fsm  : the FSM sequencing header and payload emission.
    """
    def __init__(self, link):
        self.sink = sink = stream.Endpoint(transport_tx_description(32))

        # # #

        # Header staging: wide enough for the longest header this module can emit.
        cmd_ndwords = max(fis_reg_h2d_header.length, fis_data_header.length)
        encoded_cmd = Signal(cmd_ndwords*32)

        # Index of the header word currently driven on the link. Cleared in IDLE.
        counter       = Signal(max=cmd_ndwords+1)
        counter_ce    = Signal()
        counter_reset = Signal()
        self.sync += [
            If(counter_reset,
                counter.eq(0)
            ).Elif(counter_ce,
                counter.eq(counter + 1)
            )
        ]

        # Index of the last header word (header length - 1) and whether a payload follows it.
        cmd_len       = Signal(len(counter))
        cmd_with_data = Signal()

        # Mode selects driven by the FSM, and the "last header word accepted" strobe.
        cmd_send  = Signal()
        data_send = Signal()
        cmd_done  = Signal()

        # NOTE: no FIS type register is kept here. The type is available on `sink.type` for as
        # long as the command is held on `sink`, and nothing in this module reads a latched copy.

        def test_type_tx(name):
            return test_type(name, sink.type)

        self.fsm = fsm = FSM(reset_state="IDLE")
        self.submodules += fsm

        # IDLE: hold supported commands on `sink` (ready low) and dispatch on their type;
        # otherwise keep `sink` ready so unsupported beats are consumed.
        fsm.act("IDLE",
            sink.ready.eq(0),
            counter_reset.eq(1),
            If(sink.valid,
                If(test_type_tx("REG_H2D"),
                    NextState("SEND_CTRL_CMD")
                ).Elif(test_type_tx("DATA"),
                    NextState("SEND_DATA_CMD")
                ).Else(
                    sink.ready.eq(1)
                )
            ).Else(
                sink.ready.eq(1)
            )
        )

        # SEND_CTRL_CMD: emit the REG_H2D header, then acknowledge the command on `sink`.
        fsm.act("SEND_CTRL_CMD",
            fis_reg_h2d_header.encode(sink, encoded_cmd),
            cmd_len.eq(fis_reg_h2d_header.length-1),
            cmd_send.eq(1),
            If(cmd_done,
                sink.ready.eq(1),
                NextState("IDLE")
            )
        )

        # SEND_DATA_CMD: emit the DATA header while keeping `sink` stalled, so the first payload
        # word is still pending when SEND_DATA starts.
        fsm.act("SEND_DATA_CMD",
            sink.ready.eq(0),
            fis_data_header.encode(sink, encoded_cmd),
            cmd_len.eq(fis_data_header.length-1),
            cmd_with_data.eq(1),
            cmd_send.eq(1),
            If(cmd_done,
                NextState("SEND_DATA")
            )
        )

        # SEND_DATA: pass payload words through until the beat flagged `last` is accepted.
        fsm.act("SEND_DATA",
            data_send.eq(1),
            sink.ready.eq(link.sink.ready),
            If(sink.valid & sink.last & sink.ready,
                NextState("IDLE")
            )
        )

        # One Case arm per header word: word i is bits [32*i, 32*(i+1)) of `encoded_cmd`.
        cmd_cases = {
            i: [link.sink.data.eq(encoded_cmd[32*i:32*(i+1)])]
            for i in range(cmd_ndwords)
        }

        self.comb += [
            counter_ce.eq(sink.valid & link.sink.ready),
            cmd_done.eq((counter == cmd_len) &
                        link.sink.valid &
                        link.sink.ready),
            If(cmd_send,
                link.sink.valid.eq(sink.valid),
                # The header only ends the link packet when no payload follows it.
                link.sink.last.eq((counter == cmd_len) & ~cmd_with_data),
                Case(counter, cmd_cases)
            ).Elif(data_send,
                link.sink.valid.eq(sink.valid),
                link.sink.last.eq(sink.last),
                link.sink.data.eq(sink.data)
            )
        ]

# Transport RX -------------------------------------------------------------------------------------

class LiteSATATransportRX(Module):
    """Receive half of the transport layer.

    Words arriving on `link.source` are classified by the low byte of their first word. For
    "REG_D2H", "DMA_ACTIVATE_D2H" and "PIO_SETUP_D2H" the header words are collected and then
    presented on `source` as a single beat (`last` set). For "DATA" the header is collected and
    the following words are forwarded on `source` one beat per link word. A first word of any
    other type is acknowledged in IDLE and not forwarded.

    Parameters
    ----------
    link : object providing a `source` endpoint with `valid`, `ready`, `last`, `error`, `data`.

    Attributes
    ----------
    source : stream.Endpoint built from `transport_rx_description(32)`.
    fsm    : the FSM sequencing header capture and presentation.
    """
    def __init__(self, link):
        self.source = source = stream.Endpoint(transport_rx_description(32))

        # # #

        # Header staging: wide enough for the longest header this module can receive.
        cmd_ndwords = max(fis_reg_d2h_header.length,
                          fis_dma_activate_d2h_header.length,
                          fis_pio_setup_d2h_header.length,
                          fis_data_header.length)
        encoded_cmd = Signal(cmd_ndwords*32)

        # Index of the header word being captured. Cleared in IDLE.
        counter       = Signal(max=cmd_ndwords+1)
        counter_ce    = Signal()
        counter_reset = Signal()
        self.sync += [
            If(counter_reset,
                counter.eq(0)
            ).Elif(counter_ce,
                counter.eq(counter + 1)
            )
        ]

        # Index of the last header word (header length - 1) for the FIS being received.
        cmd_len = Signal(len(counter))

        cmd_receive  = Signal()
        data_receive = Signal()
        cmd_done     = Signal()

        def test_type_rx(name):
            return test_type(name, link.source.data[:8])

        self.fsm = fsm = FSM(reset_state="IDLE")
        self.submodules += fsm

        # Type byte of the FIS in flight, refreshed on every IDLE cycle so that it holds the
        # first word's type once IDLE is left.
        fis_type        = Signal(8)
        update_fis_type = Signal()

        # IDLE: stall the link on a first word of a supported type (it is consumed as header
        # word 0 by the RECEIVE state); otherwise keep the link ready.
        fsm.act("IDLE",
            link.source.ready.eq(0),
            counter_reset.eq(1),
            update_fis_type.eq(1),
            If(link.source.valid,
                If(test_type_rx("REG_D2H"),
                    NextState("RECEIVE_CTRL_CMD")
                ).Elif(test_type_rx("DMA_ACTIVATE_D2H"),
                    NextState("RECEIVE_CTRL_CMD")
                ).Elif(test_type_rx("PIO_SETUP_D2H"),
                    NextState("RECEIVE_CTRL_CMD")
                ).Elif(test_type_rx("DATA"),
                    NextState("RECEIVE_DATA_CMD")
                ).Else(
                    link.source.ready.eq(1)
                )
            ).Else(
                link.source.ready.eq(1)
            )
        )
        self.sync += [
            If(update_fis_type, fis_type.eq(link.source.data[:8]))
        ]

        # RECEIVE_CTRL_CMD: capture the header; its length depends on the latched type
        # (the final Else covers PIO_SETUP_D2H, the only remaining type routed here).
        fsm.act("RECEIVE_CTRL_CMD",
            If(test_type("REG_D2H", fis_type),
                cmd_len.eq(fis_reg_d2h_header.length-1)
            ).Elif(test_type("DMA_ACTIVATE_D2H", fis_type),
                cmd_len.eq(fis_dma_activate_d2h_header.length-1)
            ).Else(
                cmd_len.eq(fis_pio_setup_d2h_header.length-1)
            ),
            cmd_receive.eq(1),
            link.source.ready.eq(1),
            If(cmd_done,
                NextState("PRESENT_CTRL_CMD")
            )
        )

        # PRESENT_CTRL_CMD: offer the decoded header as a one-beat packet until it is accepted.
        fsm.act("PRESENT_CTRL_CMD",
            source.valid.eq(1),
            source.last.eq(1),
            If(test_type("REG_D2H", fis_type),
                fis_reg_d2h_header.decode(encoded_cmd, source)
            ).Elif(test_type("DMA_ACTIVATE_D2H", fis_type),
                fis_dma_activate_d2h_header.decode(encoded_cmd, source)
            ).Else(
                fis_pio_setup_d2h_header.decode(encoded_cmd, source)
            ),
            If(source.valid & source.ready,
                NextState("IDLE")
            )
        )

        # RECEIVE_DATA_CMD: capture the DATA header.
        fsm.act("RECEIVE_DATA_CMD",
            cmd_len.eq(fis_data_header.length-1),
            cmd_receive.eq(1),
            link.source.ready.eq(1),
            If(cmd_done,
                NextState("PRESENT_DATA")
            )
        )

        # PRESENT_DATA: forward payload words alongside the decoded header fields, with
        # back-pressure from `source`, until the beat flagged `last` is accepted.
        fsm.act("PRESENT_DATA",
            data_receive.eq(1),
            source.valid.eq(link.source.valid),
            fis_data_header.decode(encoded_cmd, source),
            source.last.eq(link.source.last),
            source.error.eq(link.source.error),
            source.data.eq(link.source.data),
            link.source.ready.eq(source.ready),
            If(source.valid & source.last & source.ready,
                NextState("IDLE")
            )
        )

        # One Case arm per header word: word i lands in bits [32*i, 32*(i+1)) of `encoded_cmd`.
        cmd_cases = {
            i: [encoded_cmd[32*i:32*(i+1)].eq(link.source.data)]
            for i in range(cmd_ndwords)
        }

        self.comb += [
            If(cmd_receive & link.source.valid,
                counter_ce.eq(1)
            )
        ]
        # The slot selected by `counter` is rewritten every cycle while receiving; the value
        # that sticks is the one written on the cycle the counter advances (link word valid).
        self.sync += [
            If(cmd_receive,
                Case(counter, cmd_cases)
            )
        ]
        # The header is complete only when its last word is actually transferred. `ready` is
        # tied high in the RECEIVE states, so `valid` must be part of the condition; otherwise
        # a gap on the link at the last word would end the header one word early.
        self.comb += cmd_done.eq((counter == cmd_len) &
                                 link.source.valid &
                                 link.source.ready)

# Transport ----------------------------------------------------------------------------------------

class LiteSATATransport(Module):
    """Transport layer wrapper: one TX and one RX instance built on the same `link`.

    Exposes the TX command input as `sink` and the RX output as `source`; the two halves
    are also reachable as the `tx` and `rx` submodules.
    """
    def __init__(self, link):
        self.submodules.tx = tx = LiteSATATransportTX(link)
        self.submodules.rx = rx = LiteSATATransportRX(link)

        # Re-export the user-facing endpoints of each half.
        self.sink   = tx.sink
        self.source = rx.source
