diff --git a/README.md b/README.md index e6ccd51..d8c7d30 100644 --- a/README.md +++ b/README.md @@ -229,6 +229,8 @@ To use these modules, import the one you need and connect it to the DUT: The first argument to the constructor accepts an `AxiStreamBus` object. This object is a container for the interface signals and includes class methods to automate connections. +To allow `AxiStreamSource` to interleave data set the interleave parameter a dictionary containing `tid` and / or `tdest`. THe maximum interleave depth can also be set with `max_interleave_depth`. By default is is unbound. + To send data into a design with an `AxiStreamSource`, call `send()`/`send_nowait()` or `write()`/`write_nowait()`. Accepted data types are iterables or `AxiStreamFrame` objects. Optionally, call `wait()` to wait for the transmit operation to complete. Example: await axis_source.send(b'test data') @@ -246,6 +248,11 @@ To receive data with an `AxiStreamSink` or `AxiStreamMonitor`, call `recv()`/`re data = await axis_sink.recv() +To deinterleave receive data the `interleave` parameter can be set on the `AxiStreamSink` constructor. This causes calls to `read()` and `recv()` to return data sorted by `tid` ot `tdest`, returned in order of transaction completion time. + + axis_sink = AxiStreamSink(AxiStreamBus.from_prefix(dut, "m_axis"), dut.clk, dut.rst, interleave="tid") + data = await axis_sink.recv() + #### Signals * `tdata`: data, required diff --git a/cocotbext/axi/axis.py b/cocotbext/axi/axis.py index aa83d1d..fd1f015 100644 --- a/cocotbext/axi/axis.py +++ b/cocotbext/axi/axis.py @@ -33,6 +33,9 @@ from .version import __version__ from .reset import Reset +from functools import reduce +from random import choice + class AxiStreamFrame: def __init__(self, tdata=b'', tkeep=None, tid=None, tdest=None, tuser=None, tx_complete=None): @@ -261,9 +264,12 @@ class AxiStreamBase(Reset): _ready_init = None def __init__(self, bus, clock, reset=None, reset_active_level=True, - byte_size=None, byte_lanes=None, *args, **kwargs): + byte_size=None, byte_lanes=None, interleave=None, *args, **kwargs): self.bus = bus + self.interleave =interleave + if not self.interleave: + self.interleave = {} self.clock = clock self.reset = reset self.log = logging.getLogger(f"cocotb.{bus._entity._name}.{bus._name}") @@ -275,10 +281,15 @@ def __init__(self, bus, clock, reset=None, reset_active_level=True, super().__init__(*args, **kwargs) + if "tid" in self.interleave and not hasattr(self.bus, "tid"): + raise ValueError("Cannot interleave with tid on a bus without tid") + if "tdest" in self.interleave and not hasattr(self.bus, "tdest"): + raise ValueError("Cannot interleave with tdest on a bus without tdest") + self.active = False self.queue = Queue() self.dequeue_event = Event() - self.current_frame = None + self.current_frames = {} self.idle_event = Event() self.idle_event.set() self.active_event = Event() @@ -425,14 +436,20 @@ class AxiStreamSource(AxiStreamBase, AxiStreamPause): _ready_init = None def __init__(self, bus, clock, reset=None, reset_active_level=True, - byte_size=None, byte_lanes=None, *args, **kwargs): + byte_size=None, byte_lanes=None, interleave=None, max_interleave_depth=None, *args, **kwargs): - super().__init__(bus, clock, reset, reset_active_level, byte_size, byte_lanes, *args, **kwargs) + super().__init__(bus, clock, reset, reset_active_level, byte_size, byte_lanes, interleave, *args, **kwargs) + self.max_interleave_depth = max_interleave_depth self.queue_occupancy_limit_bytes = -1 self.queue_occupancy_limit_frames = -1 async def send(self, frame): + # If interleaving enabled, check provided frame has the required parameter(s) + if "tid" in self.interleave and (frame.tid is None or type(frame.tid) is list): + raise ValueError("Sending a frame with interleaving on tid requires single tid be associated with the frame") + if "dest" in self.interleave and (frame.tdest is None or type(frame.tdest) is list): + raise ValueError("Sending a frame with interleaving on tdest requires single tdest be associated with the frame") while self.full(): self.dequeue_event.clear() await self.dequeue_event.wait() @@ -444,6 +461,11 @@ async def send(self, frame): self.queue_occupancy_frames += 1 def send_nowait(self, frame): + # If interleaving enabled, check provided frame has the required parameter(s) + if "tid" in self.interleave and (frame.tid is None or type(frame.tid) is list): + raise ValueError("Sending a frame with interleaving on tid requires single tid be associated with the frame") + if "dest" in self.interleave and (frame.tdest is None or type(frame.tdest) is list): + raise ValueError("Sending a frame with interleaving on tdest requires single tdest be associated with the frame") if self.full(): raise QueueFull() frame = AxiStreamFrame(frame) @@ -491,14 +513,19 @@ def _handle_reset(self, state): if hasattr(self.bus, "tuser"): self.bus.tuser.value = 0 - if self.current_frame: - self.log.warning("Flushed transmit frame during reset: %s", self.current_frame) - self.current_frame.handle_tx_complete() - self.current_frame = None + for current_frame in self.current_frames.values(): + self.log.warning("Flushed transmit frame during reset: %s", current_frame) + current_frame.handle_tx_complete() + self.current_frames = {} async def _run(self): - frame = None - frame_offset = 0 + # next frame hold the most recently popped frame from the Queue + # It may be held if the number of entries in frames is >= max_interleave_depth + next_frame = None + # Frames holds the in-flight frame for each of the interleaved stream + frames = {} + frame_offsets = {} + self.active = False has_tready = hasattr(self.bus, "tready") @@ -519,18 +546,36 @@ async def _run(self): tvalid_sample = (not has_tvalid) or self.bus.tvalid.value if (tready_sample and tvalid_sample) or not tvalid_sample: - if not frame and not self.queue.empty(): - frame = self.queue.get_nowait() - self.dequeue_event.set() - self.queue_occupancy_bytes -= len(frame) - self.queue_occupancy_frames -= 1 - self.current_frame = frame - frame.sim_time_start = get_sim_time() - frame.sim_time_end = None - self.log.info("TX frame: %s", frame) - frame.normalize() - self.active = True - frame_offset = 0 + + # Pop a frame from the queue if we have space + if not next_frame and not self.queue.empty(): + next_frame = self.queue.get_nowait() + + # Schedule the previously popped frame if that doesn't exceed our limits + if next_frame and (self.max_interleave_depth is None or len(frames) < self.max_interleave_depth): + k = (int(next_frame.tid) if "tid" in self.interleave else None, int(next_frame.tdest) if "tdest" in self.interleave else None) + if frames.get(k) == None: + frame = next_frame + next_frame = None + self.dequeue_event.set() + self.queue_occupancy_bytes -= len(frame) + self.queue_occupancy_frames -= 1 + self.current_frames[k] = frame + frame.sim_time_start = get_sim_time() + frame.sim_time_end = None + self.log.info("TX frame: %s", frame) + frame.normalize() + self.active = True + frames[k] = frame + frame_offsets[k] = 0 + + frame = None + k = None + frame_offset = 0 + if frames: + k = choice(list(frames.keys())) + frame = frames[k] + frame_offset = frame_offsets[k] if frame and not self.pause: tdata_val = 0 @@ -547,15 +592,17 @@ async def _run(self): tdest_val = frame.tdest[frame_offset] tuser_val = frame.tuser[frame_offset] frame_offset += 1 + frame_offsets[k] = frame_offset if frame_offset >= len(frame.tdata): tlast_val = 1 frame.sim_time_end = get_sim_time() frame.handle_tx_complete() - frame = None - self.current_frame = None + del frames[k] + del self.current_frames[k] + del frame_offsets[k] break - + self.bus.tdata.value = tdata_val if has_tvalid: self.bus.tvalid.value = 1 @@ -574,8 +621,8 @@ async def _run(self): self.bus.tvalid.value = 0 if has_tlast: self.bus.tlast.value = 0 - self.active = bool(frame) - if not frame and self.queue.empty(): + self.active = bool(frames) + if not frames and self.empty(): self.idle_event.set() self.active_event.clear() @@ -592,9 +639,9 @@ class AxiStreamMonitor(AxiStreamBase): _ready_init = None def __init__(self, bus, clock, reset=None, reset_active_level=True, - byte_size=None, byte_lanes=None, *args, **kwargs): + byte_size=None, byte_lanes=None, interleave=None, *args, **kwargs): - super().__init__(bus, clock, reset, reset_active_level, byte_size, byte_lanes, *args, **kwargs) + super().__init__(bus, clock, reset, reset_active_level, byte_size, byte_lanes, interleave, *args, **kwargs) self.read_queue = [] @@ -666,7 +713,7 @@ async def _run_tready_monitor(self): self.wake_event.set() async def _run(self): - frame = None + frames = {} self.active = False has_tready = hasattr(self.bus, "tready") @@ -689,6 +736,9 @@ async def _run(self): tvalid_sample = (not has_tvalid) or self.bus.tvalid.value if tready_sample and tvalid_sample: + k = (int(self.bus.tid.value) if "tid" in self.interleave else None, int(self.bus.tdest.value) if "tdest" in self.interleave else None) + frame = frames.pop(k, None) + if not frame: if self.byte_size == 8: frame = AxiStreamFrame(bytearray(), [], [], [], []) @@ -717,8 +767,8 @@ async def _run(self): self.queue.put_nowait(frame) self.active_event.set() - - frame = None + else: + frames[k] = frame else: self.active = bool(frame) @@ -736,12 +786,12 @@ class AxiStreamSink(AxiStreamMonitor, AxiStreamPause): _ready_init = 0 def __init__(self, bus, clock, reset=None, reset_active_level=True, - byte_size=None, byte_lanes=None, *args, **kwargs): + byte_size=None, byte_lanes=None, interleave=None, *args, **kwargs): self.queue_occupancy_limit_bytes = -1 self.queue_occupancy_limit_frames = -1 - super().__init__(bus, clock, reset, reset_active_level, byte_size, byte_lanes, *args, **kwargs) + super().__init__(bus, clock, reset, reset_active_level, byte_size, byte_lanes, interleave, *args, **kwargs) def full(self): if self.queue_occupancy_limit_bytes > 0 and self.queue_occupancy_bytes > self.queue_occupancy_limit_bytes: @@ -765,7 +815,7 @@ def _dequeue(self, frame): self.wake_event.set() async def _run(self): - frame = None + frames = {} self.active = False has_tready = hasattr(self.bus, "tready") @@ -790,6 +840,9 @@ async def _run(self): tvalid_sample = (not has_tvalid) or self.bus.tvalid.value if tready_sample and tvalid_sample: + k = (int(self.bus.tid.value) if "tid" in self.interleave else None, int(self.bus.tdest.value) if "tdest" in self.interleave else None) + frame = frames.pop(k, None) + if not frame: if self.byte_size == 8: frame = AxiStreamFrame(bytearray(), [], [], [], []) @@ -818,10 +871,10 @@ async def _run(self): self.queue.put_nowait(frame) self.active_event.set() - - frame = None + else: + frames[k] = frame else: - self.active = bool(frame) + self.active = reduce(lambda r, f: r or bool(f), frames, False) if has_tready: self.bus.tready.value = (not self.full() and not pause_sample) diff --git a/tests/axis/test_axis.py b/tests/axis/test_axis.py index 7ed0aab..18cd50f 100644 --- a/tests/axis/test_axis.py +++ b/tests/axis/test_axis.py @@ -39,7 +39,7 @@ class TB: - def __init__(self, dut): + def __init__(self, dut, interleave): self.dut = dut self.log = logging.getLogger("cocotb.tb") @@ -47,9 +47,9 @@ def __init__(self, dut): cocotb.start_soon(Clock(dut.clk, 2, units="ns").start()) - self.source = AxiStreamSource(AxiStreamBus.from_prefix(dut, "axis"), dut.clk, dut.rst) - self.sink = AxiStreamSink(AxiStreamBus.from_prefix(dut, "axis"), dut.clk, dut.rst) - self.monitor = AxiStreamMonitor(AxiStreamBus.from_prefix(dut, "axis"), dut.clk, dut.rst) + self.source = AxiStreamSource(AxiStreamBus.from_prefix(dut, "axis"), dut.clk, dut.rst, interleave=interleave) + self.sink = AxiStreamSink(AxiStreamBus.from_prefix(dut, "axis"), dut.clk, dut.rst, interleave=interleave) + self.monitor = AxiStreamMonitor(AxiStreamBus.from_prefix(dut, "axis"), dut.clk, dut.rst, interleave=interleave) def set_idle_generator(self, generator=None): if generator: @@ -71,9 +71,9 @@ async def reset(self): await RisingEdge(self.dut.clk) -async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=None, backpressure_inserter=None): +async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=None, backpressure_inserter=None, interleave=None): - tb = TB(dut) + tb = TB(dut,interleave) id_count = 2**len(tb.source.bus.tid) @@ -141,6 +141,7 @@ def incrementing_payload(length): factory.add_option("payload_lengths", [size_list]) factory.add_option("payload_data", [incrementing_payload]) factory.add_option("idle_inserter", [None, cycle_pause]) + factory.add_option("interleave", [None, "tid", "tdest", {"tid", "tdest"} ]) factory.add_option("backpressure_inserter", [None, cycle_pause]) factory.generate_tests()