Skip to content

Commit

Permalink
wishbone.bus: add Arbiter.
Browse files Browse the repository at this point in the history
  • Loading branch information
Jean-François Nguyen authored Feb 12, 2020
1 parent f8f8982 commit 967a65f
Show file tree
Hide file tree
Showing 2 changed files with 383 additions and 1 deletion.
265 changes: 265 additions & 0 deletions nmigen_soc/test/test_wishbone_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,268 @@ def sim_test():
with Simulator(m, vcd_file=open("test.vcd", "w")) as sim:
sim.add_process(sim_test())
sim.run()


class ArbiterTestCase(unittest.TestCase):
def setUp(self):
self.dut = Arbiter(addr_width=31, data_width=32, granularity=16,
features={"err"})

def test_add_wrong(self):
with self.assertRaisesRegex(TypeError,
r"Initiator bus must be an instance of wishbone\.Interface, not 'foo'"):
self.dut.add("foo")

def test_add_wrong_addr_width(self):
with self.assertRaisesRegex(ValueError,
r"Initiator bus has address width 15, which is not the same as arbiter "
r"address width 31"):
self.dut.add(Interface(addr_width=15, data_width=32, granularity=16))

def test_add_wrong_granularity(self):
with self.assertRaisesRegex(ValueError,
r"Initiator bus has granularity 8, which is lesser than "
r"the arbiter granularity 16"):
self.dut.add(Interface(addr_width=31, data_width=32, granularity=8))

def test_add_wrong_data_width(self):
with self.assertRaisesRegex(ValueError,
r"Initiator bus has data width 16, which is not the same as arbiter "
r"data width 32"):
self.dut.add(Interface(addr_width=31, data_width=16, granularity=16))

def test_add_wrong_optional_output(self):
with self.assertRaisesRegex(ValueError,
r"Arbiter has optional output 'err', but the initiator bus does "
r"not have a corresponding input"):
self.dut.add(Interface(addr_width=31, data_width=32, granularity=16))


class ArbiterSimulationTestCase(unittest.TestCase):
def test_simple(self):
dut = Arbiter(addr_width=30, data_width=32, granularity=8,
features={"err", "rty", "stall", "lock", "cti", "bte"})
intr_1 = Interface(addr_width=30, data_width=32, granularity=8,
features={"err", "rty"})
dut.add(intr_1)
intr_2 = Interface(addr_width=30, data_width=32, granularity=16,
features={"err", "rty", "stall", "lock", "cti", "bte"})
dut.add(intr_2)

def sim_test():
yield intr_1.adr.eq(0x7ffffffc >> 2)
yield intr_1.cyc.eq(1)
yield intr_1.stb.eq(1)
yield intr_1.sel.eq(0b1111)
yield intr_1.we.eq(1)
yield intr_1.dat_w.eq(0x12345678)
yield dut.bus.dat_r.eq(0xabcdef01)
yield dut.bus.ack.eq(1)
yield dut.bus.err.eq(1)
yield dut.bus.rty.eq(1)
yield Delay(1e-7)
self.assertEqual((yield dut.bus.adr), 0x7ffffffc >> 2)
self.assertEqual((yield dut.bus.cyc), 1)
self.assertEqual((yield dut.bus.stb), 1)
self.assertEqual((yield dut.bus.sel), 0b1111)
self.assertEqual((yield dut.bus.we), 1)
self.assertEqual((yield dut.bus.dat_w), 0x12345678)
self.assertEqual((yield dut.bus.lock), 1)
self.assertEqual((yield dut.bus.cti), CycleType.CLASSIC.value)
self.assertEqual((yield dut.bus.bte), BurstTypeExt.LINEAR.value)
self.assertEqual((yield intr_1.dat_r), 0xabcdef01)
self.assertEqual((yield intr_1.ack), 1)
self.assertEqual((yield intr_1.err), 1)
self.assertEqual((yield intr_1.rty), 1)

yield intr_1.cyc.eq(0)
yield intr_2.adr.eq(0xe0000000 >> 2)
yield intr_2.cyc.eq(1)
yield intr_2.stb.eq(1)
yield intr_2.sel.eq(0b10)
yield intr_2.we.eq(1)
yield intr_2.dat_w.eq(0x43218765)
yield intr_2.lock.eq(0)
yield intr_2.cti.eq(CycleType.INCR_BURST)
yield intr_2.bte.eq(BurstTypeExt.WRAP_4)
yield Tick()

yield dut.bus.stall.eq(0)
yield Delay(1e-7)
self.assertEqual((yield dut.bus.adr), 0xe0000000 >> 2)
self.assertEqual((yield dut.bus.cyc), 1)
self.assertEqual((yield dut.bus.stb), 1)
self.assertEqual((yield dut.bus.sel), 0b1100)
self.assertEqual((yield dut.bus.we), 1)
self.assertEqual((yield dut.bus.dat_w), 0x43218765)
self.assertEqual((yield dut.bus.lock), 0)
self.assertEqual((yield dut.bus.cti), CycleType.INCR_BURST.value)
self.assertEqual((yield dut.bus.bte), BurstTypeExt.WRAP_4.value)
self.assertEqual((yield intr_2.dat_r), 0xabcdef01)
self.assertEqual((yield intr_2.ack), 1)
self.assertEqual((yield intr_2.err), 1)
self.assertEqual((yield intr_2.rty), 1)
self.assertEqual((yield intr_2.stall), 0)

with Simulator(dut, vcd_file=open("test.vcd", "w")) as sim:
sim.add_clock(1e-6)
sim.add_sync_process(sim_test())
sim.run()

def test_lock(self):
dut = Arbiter(addr_width=30, data_width=32, features={"lock"})
intr_1 = Interface(addr_width=30, data_width=32, features={"lock"})
dut.add(intr_1)
intr_2 = Interface(addr_width=30, data_width=32, features={"lock"})
dut.add(intr_2)

def sim_test():
yield intr_1.cyc.eq(1)
yield intr_1.lock.eq(1)
yield intr_2.cyc.eq(1)
yield dut.bus.ack.eq(1)
yield Delay(1e-7)
self.assertEqual((yield intr_1.ack), 1)
self.assertEqual((yield intr_2.ack), 0)

yield Tick()
yield Delay(1e-7)
self.assertEqual((yield intr_1.ack), 1)
self.assertEqual((yield intr_2.ack), 0)

yield intr_1.lock.eq(0)
yield Tick()
yield Delay(1e-7)
self.assertEqual((yield intr_1.ack), 0)
self.assertEqual((yield intr_2.ack), 1)

yield intr_2.cyc.eq(0)
yield Tick()
yield Delay(1e-7)
self.assertEqual((yield intr_1.ack), 1)
self.assertEqual((yield intr_2.ack), 0)

yield intr_1.stb.eq(1)
yield Tick()
yield Delay(1e-7)
self.assertEqual((yield intr_1.ack), 1)
self.assertEqual((yield intr_2.ack), 0)

yield intr_1.stb.eq(0)
yield intr_2.cyc.eq(1)
yield Tick()
yield Delay(1e-7)
self.assertEqual((yield intr_1.ack), 0)
self.assertEqual((yield intr_2.ack), 1)

with Simulator(dut, vcd_file=open("test.vcd", "w")) as sim:
sim.add_clock(1e-6)
sim.add_sync_process(sim_test())
sim.run()

def test_stall(self):
dut = Arbiter(addr_width=30, data_width=32, features={"stall"})
intr_1 = Interface(addr_width=30, data_width=32, features={"stall"})
dut.add(intr_1)
intr_2 = Interface(addr_width=30, data_width=32, features={"stall"})
dut.add(intr_2)

def sim_test():
yield intr_1.cyc.eq(1)
yield intr_2.cyc.eq(1)
yield dut.bus.stall.eq(0)
yield Delay(1e-6)
self.assertEqual((yield intr_1.stall), 0)
self.assertEqual((yield intr_2.stall), 1)

yield dut.bus.stall.eq(1)
yield Delay(1e-6)
self.assertEqual((yield intr_1.stall), 1)
self.assertEqual((yield intr_2.stall), 1)

with Simulator(dut, vcd_file=open("test.vcd", "w")) as sim:
sim.add_process(sim_test())
sim.run()

def test_stall_compat(self):
dut = Arbiter(addr_width=30, data_width=32)
intr_1 = Interface(addr_width=30, data_width=32, features={"stall"})
dut.add(intr_1)
intr_2 = Interface(addr_width=30, data_width=32, features={"stall"})
dut.add(intr_2)

def sim_test():
yield intr_1.cyc.eq(1)
yield intr_2.cyc.eq(1)
yield Delay(1e-6)
self.assertEqual((yield intr_1.stall), 1)
self.assertEqual((yield intr_2.stall), 1)

yield dut.bus.ack.eq(1)
yield Delay(1e-6)
self.assertEqual((yield intr_1.stall), 0)
self.assertEqual((yield intr_2.stall), 1)

with Simulator(dut, vcd_file=open("test.vcd", "w")) as sim:
sim.add_process(sim_test())
sim.run()

def test_roundrobin(self):
dut = Arbiter(addr_width=30, data_width=32)
intr_1 = Interface(addr_width=30, data_width=32)
dut.add(intr_1)
intr_2 = Interface(addr_width=30, data_width=32)
dut.add(intr_2)
intr_3 = Interface(addr_width=30, data_width=32)
dut.add(intr_3)

def sim_test():
yield intr_1.cyc.eq(1)
yield intr_2.cyc.eq(0)
yield intr_3.cyc.eq(1)
yield dut.bus.ack.eq(1)
yield Delay(1e-7)
self.assertEqual((yield intr_1.ack), 1)
self.assertEqual((yield intr_2.ack), 0)
self.assertEqual((yield intr_3.ack), 0)

yield intr_1.cyc.eq(0)
yield intr_2.cyc.eq(0)
yield intr_3.cyc.eq(1)
yield Tick()
yield Delay(1e-7)
self.assertEqual((yield intr_1.ack), 0)
self.assertEqual((yield intr_2.ack), 0)
self.assertEqual((yield intr_3.ack), 1)

yield intr_1.cyc.eq(1)
yield intr_2.cyc.eq(1)
yield intr_3.cyc.eq(0)
yield Tick()
yield Delay(1e-7)
self.assertEqual((yield intr_1.ack), 1)
self.assertEqual((yield intr_2.ack), 0)
self.assertEqual((yield intr_3.ack), 0)

yield intr_1.cyc.eq(0)
yield intr_2.cyc.eq(1)
yield intr_3.cyc.eq(1)
yield Tick()
yield Delay(1e-7)
self.assertEqual((yield intr_1.ack), 0)
self.assertEqual((yield intr_2.ack), 1)
self.assertEqual((yield intr_3.ack), 0)

yield intr_1.cyc.eq(1)
yield intr_2.cyc.eq(0)
yield intr_3.cyc.eq(1)
yield Tick()
yield Delay(1e-7)
self.assertEqual((yield intr_1.ack), 0)
self.assertEqual((yield intr_2.ack), 0)
self.assertEqual((yield intr_3.ack), 1)

with Simulator(dut, vcd_file=open("test.vcd", "w")) as sim:
sim.add_clock(1e-6)
sim.add_sync_process(sim_test())
sim.run()
119 changes: 118 additions & 1 deletion nmigen_soc/wishbone/bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ..memory import MemoryMap


__all__ = ["CycleType", "BurstTypeExt", "Interface", "Decoder"]
__all__ = ["CycleType", "BurstTypeExt", "Interface", "Decoder", "Arbiter"]


class CycleType(Enum):
Expand Down Expand Up @@ -311,3 +311,120 @@ def elaborate(self, platform):
m.d.comb += self.bus.stall.eq(stall_fanin)

return m


class Arbiter(Elaboratable):
"""Wishbone bus arbiter.
A round-robin arbiter for initiators accessing a shared Wishbone bus.
Parameters
----------
addr_width : int
Address width. See :class:`Interface`.
data_width : int
Data width. See :class:`Interface`.
granularity : int
Granularity. See :class:`Interface`
features : iter(str)
Optional signal set. See :class:`Interface`.
Attributes
----------
bus : :class:`Interface`
Shared Wishbone bus.
"""
def __init__(self, *, addr_width, data_width, granularity=None, features=frozenset()):
self.bus = Interface(addr_width=addr_width, data_width=data_width,
granularity=granularity, features=features)
self._intrs = []

def add(self, intr_bus):
"""Add an initiator bus to the arbiter.
The initiator bus must have the same address width and data width as the arbiter. The
granularity of the initiator bus must be greater than or equal to the granularity of
the arbiter.
"""
if not isinstance(intr_bus, Interface):
raise TypeError("Initiator bus must be an instance of wishbone.Interface, not {!r}"
.format(intr_bus))
if intr_bus.addr_width != self.bus.addr_width:
raise ValueError("Initiator bus has address width {}, which is not the same as "
"arbiter address width {}"
.format(intr_bus.addr_width, self.bus.addr_width))
if intr_bus.granularity < self.bus.granularity:
raise ValueError("Initiator bus has granularity {}, which is lesser than the "
"arbiter granularity {}"
.format(intr_bus.granularity, self.bus.granularity))
if intr_bus.data_width != self.bus.data_width:
raise ValueError("Initiator bus has data width {}, which is not the same as "
"arbiter data width {}"
.format(intr_bus.data_width, self.bus.data_width))
for opt_output in {"err", "rty"}:
if hasattr(self.bus, opt_output) and not hasattr(intr_bus, opt_output):
raise ValueError("Arbiter has optional output {!r}, but the initiator bus "
"does not have a corresponding input"
.format(opt_output))

self._intrs.append(intr_bus)

def elaborate(self, platform):
m = Module()

requests = Signal(len(self._intrs))
grant = Signal(range(len(self._intrs)))
m.d.comb += requests.eq(Cat(intr_bus.cyc for intr_bus in self._intrs))

bus_busy = self.bus.cyc
if hasattr(self.bus, "lock"):
# If LOCK is not asserted, we also wait for STB to be deasserted before granting bus
# ownership to the next initiator. If we didn't, the next bus owner could receive
# an ACK (or ERR, RTY) from the previous transaction when targeting the same
# peripheral.
bus_busy &= self.bus.lock | self.bus.stb

with m.If(~bus_busy):
with m.Switch(grant):
for i in range(len(requests)):
with m.Case(i):
for pred in reversed(range(i)):
with m.If(requests[pred]):
m.d.sync += grant.eq(pred)
for succ in reversed(range(i + 1, len(requests))):
with m.If(requests[succ]):
m.d.sync += grant.eq(succ)

with m.Switch(grant):
for i, intr_bus in enumerate(self._intrs):
m.d.comb += intr_bus.dat_r.eq(self.bus.dat_r)
if hasattr(intr_bus, "stall"):
intr_bus_stall = Signal(reset=1)
m.d.comb += intr_bus.stall.eq(intr_bus_stall)

with m.Case(i):
ratio = intr_bus.granularity // self.bus.granularity
m.d.comb += [
self.bus.adr.eq(intr_bus.adr),
self.bus.dat_w.eq(intr_bus.dat_w),
self.bus.sel.eq(Cat(Repl(sel, ratio) for sel in intr_bus.sel)),
self.bus.we.eq(intr_bus.we),
self.bus.stb.eq(intr_bus.stb),
]
m.d.comb += self.bus.cyc.eq(intr_bus.cyc)
if hasattr(self.bus, "lock"):
m.d.comb += self.bus.lock.eq(getattr(intr_bus, "lock", 1))
if hasattr(self.bus, "cti"):
m.d.comb += self.bus.cti.eq(getattr(intr_bus, "cti", CycleType.CLASSIC))
if hasattr(self.bus, "bte"):
m.d.comb += self.bus.bte.eq(getattr(intr_bus, "bte", BurstTypeExt.LINEAR))

m.d.comb += intr_bus.ack.eq(self.bus.ack)
if hasattr(intr_bus, "err"):
m.d.comb += intr_bus.err.eq(getattr(self.bus, "err", 0))
if hasattr(intr_bus, "rty"):
m.d.comb += intr_bus.rty.eq(getattr(self.bus, "rty", 0))
if hasattr(intr_bus, "stall"):
m.d.comb += intr_bus_stall.eq(getattr(self.bus, "stall", ~self.bus.ack))

return m

0 comments on commit 967a65f

Please sign in to comment.