Skip to content

Commit

Permalink
Multihead support for ITA (#16)
Browse files Browse the repository at this point in the history
* [feature] MHSA for ITA

Features
- Added configuration file for MemPool with ITA

Changes
- ITA: Load RQS parameter from memory
- ITA: Support user-specified output address for results
- ITA: Support different matrix shapes
- ITA: Support 32-bit row-wise biases
- ITA: Fetch Q and K always from ITA Core 0

Fix
- ITA: Fix overflow bug in streaming_partial_softmax

Important Note
The softmax values are limited to a maximum of 127 because the hardware's `sumdot` modules currently support only signed-signed operations. This is a temporary workaround until `sumdot` is fixed.
  • Loading branch information
Xeratec authored Feb 28, 2024
1 parent cb45f43 commit 29607d2
Show file tree
Hide file tree
Showing 4 changed files with 468 additions and 275 deletions.
8 changes: 2 additions & 6 deletions config/mempool.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,9 @@ memory:
latency: 5
callbacks:
- name: zero-memory
size: 0x40
- name: mempool-ita
size: 32
- name: zero-memory
size: 0xFFA0
size: 0x10000
- name: mempool-dma
size: 28
size: 0x1C
inst_latency:
mul: 3
mulh: 3
Expand Down
57 changes: 57 additions & 0 deletions config/mempool_ita.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2021 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0

---
address:
scratch_reg: 0x40000000
wakeup_reg: 0x40000004
tcdm_start: 0x40000008
tcdm_end: 0x4000000C
nr_cores: 0x40000010
uart: 0xC0000000
# Not supported in MemPool
barrier_reg:
start: 0x50000000
offset: 0x100000
cluster_base_hartid: 0x50000001
cluster_num: 0x50000002
cluster_id: 0x50000003
cl_clint: 0x40000060
clint: 0xFFFF0000
memory:
tcdm:
start: 0x0
size: 0x100000
offset: 0x100000
latency: 5
dram:
start: 0x80000000
size: 0x01000000
offset: 0x0
latency: 10
periphs:
start: 0x40000000
size: 0x20000
offset: 0x0
latency: 5
callbacks:
- name: zero-memory
size: 0x40
- name: mempool-ita
size: 0xC0
- name: zero-memory
size: 0xFF00
- name: mempool-dma
size: 0x1C
inst_latency:
mul: 3
mulh: 3
mulhsu: 3
mulhu: 3
div: 3
divu: 3
rem: 3
remu: 3
ssr:
num_dm: 3
22 changes: 10 additions & 12 deletions src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,8 @@ impl Engine {
+ self.config.memory.tcdm.offset * i as u32
+ self.config.memory.tcdm.size)
{
debug!("Entering TCDM allocation section.");
debug!(
trace!("Entering TCDM allocation section.");
trace!(
"Writing value into position: 0x{:x}",
((addr
- ((self.config.memory.tcdm.start
Expand Down Expand Up @@ -756,8 +756,8 @@ impl<'a, 'b> Cpu<'a, 'b> {
+ self.engine.config.memory.tcdm.size)
}) =>
{
debug!("TCDM Binary Load");
debug!("Binary load address: 0x{:x}", x);
trace!("TCDM Binary Load");
trace!("Binary load address: 0x{:x}", x);
let id = (0..self.engine.num_clusters)
.position(|i| {
addr >= (self.engine.config.memory.tcdm.start
Expand Down Expand Up @@ -786,8 +786,8 @@ impl<'a, 'b> Cpu<'a, 'b> {
+ self.engine.config.memory.periphs.size)
}) =>
{
debug!("Peripheral Binary Load");
debug!("Binary load address: 0x{:x}", x);
trace!("Peripheral Binary Load");
trace!("Binary load address: 0x{:x}", x);
let id = (0..self.engine.num_clusters)
.position(|i| {
addr >= (self.engine.config.memory.periphs.start
Expand Down Expand Up @@ -898,8 +898,8 @@ impl<'a, 'b> Cpu<'a, 'b> {
+ self.engine.config.memory.tcdm.size)
}) =>
{
debug!("TCDM Binary Store");
debug!("Binary store address: 0x{:x}", x);
trace!("TCDM Binary Store");
trace!("Binary store address: 0x{:x}", x);
let id = (0..self.engine.num_clusters)
.position(|i| {
addr >= (self.engine.config.memory.tcdm.start
Expand Down Expand Up @@ -933,8 +933,8 @@ impl<'a, 'b> Cpu<'a, 'b> {
+ self.engine.config.memory.periphs.size)
}) =>
{
debug!("Peripheral Binary store");
debug!("Binary store address: 0x{:x}", x);
trace!("Peripheral Binary store");
trace!("Binary store address: 0x{:x}", x);
let id = (0..self.engine.num_clusters)
.position(|i| {
addr >= (self.engine.config.memory.periphs.start
Expand Down Expand Up @@ -1059,7 +1059,6 @@ impl<'a, 'b> Cpu<'a, 'b> {
// n in bytes
trace!("MEMCPY From {:08x} to {:08x} num: {:08x}", src, dest, n);
if dest % 4 == 0 && src % 4 == 0 && n % 4 == 0 {
warn!("MEMCPY aligned");
// Aligned transfer
for _ in 0..n / 4 {
let tmp = self.binary_load(src, 2);
Expand All @@ -1068,7 +1067,6 @@ impl<'a, 'b> Cpu<'a, 'b> {
dest += 4;
}
} else {
warn!("MEMCPY unaligned");
for _ in 0..n {
let tmp = self.binary_load(src, 0);
self.binary_store(dest, tmp, (u8::MAX as u32) << (8 * (dest % 4)), 0);
Expand Down
Loading

0 comments on commit 29607d2

Please sign in to comment.