Skip to content

Commit

Permalink
Apply black & isort formatting
Browse files Browse the repository at this point in the history
Internal-tag: [#68702]
Signed-off-by: Wiktoria Kuna <[email protected]>
  • Loading branch information
wkkuna committed Dec 2, 2024
1 parent c514177 commit 236ebf9
Show file tree
Hide file tree
Showing 47 changed files with 2,894 additions and 2,103 deletions.
361 changes: 202 additions & 159 deletions rowhammer_tester/gateware/bist.py

Large diffs are not rendered by default.

344 changes: 209 additions & 135 deletions rowhammer_tester/gateware/payload_executor.py

Large diffs are not rendered by default.

38 changes: 21 additions & 17 deletions rowhammer_tester/gateware/rowhammer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from litex.soc.integration.doc import AutoDoc, ModuleDoc
from litex.soc.interconnect.csr import AutoCSR, CSRStatus, CSRStorage
from migen import *

from litex.soc.interconnect.csr import CSRStatus, CSRStorage, AutoCSR
from litex.soc.integration.doc import AutoDoc, ModuleDoc

class RowHammerDMA(Module, AutoCSR, AutoDoc, ModuleDoc):
"""
Expand All @@ -13,33 +13,37 @@ class RowHammerDMA(Module, AutoCSR, AutoDoc, ModuleDoc):
result in the DRAM controller having to repeatedly open/close rows at each
read access.
"""

def __init__(self, dma):
address_width = len(dma.sink.address)

self.enabled = CSRStorage(description="Used to start/stop the operation of the module")
self.enabled = CSRStorage(description="Used to start/stop the operation of the module")
self.address1 = CSRStorage(address_width, description="First attacked address")
self.address2 = CSRStorage(address_width, description="Second attacked address")
self.count = CSRStatus(32, description="""This is the number of DMA accesses performed.
self.count = CSRStatus(
32,
description="""This is the number of DMA accesses performed.
When the module is enabled, the value can be freely read. When
the module is disabled, the register is clear-on-write and has
to be read before the next attack.""")
to be read before the next attack.""",
)

counter = Signal.like(self.count.status)
self.comb += self.count.status.eq(counter)
self.sync += \
If(self.enabled.storage,
If(dma.sink.valid & dma.sink.ready,
counter.eq(counter + 1)
)
).Elif(self.count.we, # clear on read when not enabled
counter.eq(0)
)
self.sync += If(
self.enabled.storage, If(dma.sink.valid & dma.sink.ready, counter.eq(counter + 1))
).Elif(
self.count.we, counter.eq(0) # clear on read when not enabled
)

address = Signal(address_width)
self.comb += Case(counter[0], {
0: address.eq(self.address1.storage),
1: address.eq(self.address2.storage),
})
self.comb += Case(
counter[0],
{
0: address.eq(self.address1.storage),
1: address.eq(self.address2.storage),
},
)

self.comb += [
dma.sink.address.eq(address),
Expand Down
33 changes: 17 additions & 16 deletions rowhammer_tester/payload/ddr3lib.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
import collections
import math

import payload_ddr3_pb2

Expand All @@ -10,7 +10,7 @@


def VerifyInstr(ip: int, instr: Instr) -> bool:
if instr.HasField('mem'):
if instr.HasField("mem"):
mem = instr.mem
if mem.opcode not in {Opcode.RD, Opcode.ACT, Opcode.PRE, Opcode.REF}:
return False
Expand All @@ -26,14 +26,14 @@ def VerifyInstr(ip: int, instr: Instr) -> bool:
# We only ever want sequential (non-permuted) bursts.
return False
return True
if instr.HasField('nop'):
if instr.HasField("nop"):
nop = instr.nop
if nop.opcode != Opcode.NOP:
return False
if not (0 < nop.timeslice < (1 << Instr.NopInstr.Bits.TIMESLICE)):
return False
return True
if instr.HasField('jmp'):
if instr.HasField("jmp"):
jmp = instr.jmp
if jmp.opcode != Opcode.JMP:
return False
Expand All @@ -48,7 +48,6 @@ def VerifyInstr(ip: int, instr: Instr) -> bool:


class Rank:

def __init__(self, timing: Timing):
self.parameters = {
Opcode.RD: {
Expand Down Expand Up @@ -79,17 +78,21 @@ def __init__(self, timing: Timing):
def Execute(self, tick: int, instr: Instr.MemInstr) -> bool:
if tick < self.next_tick.get(instr.opcode, 0):
print(
'Rank timing violation for {}: {} < {}'.format(
Opcode.Name(instr.opcode), tick, self.next_tick[instr.opcode]))
"Rank timing violation for {}: {} < {}".format(
Opcode.Name(instr.opcode), tick, self.next_tick[instr.opcode]
)
)
return False

# Special-case handling for tFAW.
if instr.opcode == Opcode.ACT:
if len(self.prev_acts) == self.prev_acts.maxlen:
if tick - self.prev_acts[0] < self.faw:
print(
'tFAW timing violation for {}: {} < {}'.format(
Opcode.Name(instr.opcode), tick - self.prev_acts[0], self.faw))
"tFAW timing violation for {}: {} < {}".format(
Opcode.Name(instr.opcode), tick - self.prev_acts[0], self.faw
)
)
return False
self.prev_acts.append(tick)

Expand All @@ -105,7 +108,6 @@ def Execute(self, tick: int, instr: Instr.MemInstr) -> bool:


class Bank:

def __init__(self, timing: Timing):
self.parameters = {
Opcode.RD: {
Expand All @@ -116,18 +118,17 @@ def __init__(self, timing: Timing):
Opcode.ACT: math.inf,
Opcode.PRE: timing.ras,
},
Opcode.PRE: {
Opcode.RD: math.inf,
Opcode.ACT: timing.rp
}
Opcode.PRE: {Opcode.RD: math.inf, Opcode.ACT: timing.rp},
}
self.next_tick = {Opcode.RD: math.inf, Opcode.ACT: 0, Opcode.PRE: 0, Opcode.REF: 0}

def Execute(self, tick: int, instr: Instr.MemInstr) -> bool:
if tick < self.next_tick.get(instr.opcode, 0):
print(
'Bank timing violation for {}: {} < {}'.format(
Opcode.Name(instr.opcode), tick, self.next_tick[instr.opcode]))
"Bank timing violation for {}: {} < {}".format(
Opcode.Name(instr.opcode), tick, self.next_tick[instr.opcode]
)
)
return False

for opcode, parameter in self.parameters.get(instr.opcode, {}).items():
Expand Down
48 changes: 23 additions & 25 deletions rowhammer_tester/payload/ddr4lib.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
import collections
import math

import payload_ddr4_pb2

Expand All @@ -10,7 +10,7 @@


def VerifyInstr(ip: int, instr: Instr) -> bool:
if instr.HasField('mem'):
if instr.HasField("mem"):
mem = instr.mem
if mem.opcode not in {Opcode.RD, Opcode.ACT, Opcode.PRE, Opcode.REF}:
return False
Expand All @@ -30,14 +30,14 @@ def VerifyInstr(ip: int, instr: Instr) -> bool:
# We only ever want sequential (non-permuted) bursts.
return False
return True
if instr.HasField('nop'):
if instr.HasField("nop"):
nop = instr.nop
if nop.opcode != Opcode.NOP:
return False
if not (0 < nop.timeslice < (1 << Instr.NopInstr.Bits.TIMESLICE)):
return False
return True
if instr.HasField('jmp'):
if instr.HasField("jmp"):
jmp = instr.jmp
if jmp.opcode != Opcode.JMP:
return False
Expand All @@ -52,7 +52,6 @@ def VerifyInstr(ip: int, instr: Instr) -> bool:


class Rank:

def __init__(self, timing: Timing):
self.parameters = {
Opcode.ACT: {
Expand All @@ -78,17 +77,21 @@ def __init__(self, timing: Timing):
def Execute(self, tick: int, instr: Instr.MemInstr) -> bool:
if tick < self.next_tick.get(instr.opcode, 0):
print(
'Rank timing violation for {}: {} < {}'.format(
Opcode.Name(instr.opcode), tick, self.next_tick[instr.opcode]))
"Rank timing violation for {}: {} < {}".format(
Opcode.Name(instr.opcode), tick, self.next_tick[instr.opcode]
)
)
return False

# Special-case handling for tFAW.
if instr.opcode == Opcode.ACT:
if len(self.prev_acts) == self.prev_acts.maxlen:
if tick - self.prev_acts[0] < self.faw:
print(
'tFAW timing violation for {}: {} < {}'.format(
Opcode.Name(instr.opcode), tick - self.prev_acts[0], self.faw))
"tFAW timing violation for {}: {} < {}".format(
Opcode.Name(instr.opcode), tick - self.prev_acts[0], self.faw
)
)
return False
self.prev_acts.append(tick)

Expand All @@ -109,15 +112,10 @@ def Execute(self, tick: int, instr: Instr.MemInstr) -> bool:


class BankGroup:

def __init__(self, timing: Timing):
self.parameters = {
Opcode.RD: {
Opcode.RD: [timing.ccd_l, timing.ccd_s]
},
Opcode.ACT: {
Opcode.ACT: [timing.rrd_l, timing.rrd_s]
},
Opcode.RD: {Opcode.RD: [timing.ccd_l, timing.ccd_s]},
Opcode.ACT: {Opcode.ACT: [timing.rrd_l, timing.rrd_s]},
}
self.next_tick = {Opcode.RD: 0, Opcode.ACT: 0}

Expand All @@ -126,8 +124,10 @@ def __init__(self, timing: Timing):
def Execute(self, tick: int, instr: Instr.MemInstr) -> bool:
if tick < self.next_tick.get(instr.opcode, 0):
print(
'Bank group timing violation for {}: {} < {}'.format(
Opcode.Name(instr.opcode), tick, self.next_tick[instr.opcode]))
"Bank group timing violation for {}: {} < {}".format(
Opcode.Name(instr.opcode), tick, self.next_tick[instr.opcode]
)
)
return False

if not self.banks[instr.bank].Execute(tick, instr):
Expand All @@ -143,7 +143,6 @@ def Update(self, tick: int, instr: Instr.MemInstr):


class Bank:

def __init__(self, timing: Timing):
self.parameters = {
Opcode.RD: {
Expand All @@ -154,18 +153,17 @@ def __init__(self, timing: Timing):
Opcode.ACT: math.inf,
Opcode.PRE: timing.ras,
},
Opcode.PRE: {
Opcode.RD: math.inf,
Opcode.ACT: timing.rp
}
Opcode.PRE: {Opcode.RD: math.inf, Opcode.ACT: timing.rp},
}
self.next_tick = {Opcode.RD: math.inf, Opcode.ACT: 0, Opcode.PRE: 0, Opcode.REF: 0}

def Execute(self, tick: int, instr: Instr.MemInstr) -> bool:
if tick < self.next_tick.get(instr.opcode, 0):
print(
'Bank timing violation for {}: {} < {}'.format(
Opcode.Name(instr.opcode), tick, self.next_tick[instr.opcode]))
"Bank timing violation for {}: {} < {}".format(
Opcode.Name(instr.opcode), tick, self.next_tick[instr.opcode]
)
)
return False

for opcode, parameter in self.parameters.get(instr.opcode, {}).items():
Expand Down
Loading

0 comments on commit 236ebf9

Please sign in to comment.