From 29607d2115f95e76fa6491145d63ece4f26059ca Mon Sep 17 00:00:00 2001 From: Philip Wiese Date: Wed, 28 Feb 2024 18:00:18 +0100 Subject: [PATCH] Multihead support for ITA (#16) * [feature] MHSA for ITA Features - Added configuration file for MemPool with ITA Changes - ITA: Load RQS parameter from memory - ITA: Support user-specified output address for results - ITA: Support different matrix shapes - ITA: Support 32-bit row-vise biases - ITA: Fetch Q and K always from ITA Core 0 Fix - ITA: Fix overflow bug in streaming_partial_softmax Important Note The softmax values have a maximum value of 127 as `sumdot` modules of the hardware can only do signed-signed operations for now. This is a temporary fix until `sumdot` is fixed. --- config/mempool.yaml | 8 +- config/mempool_ita.yaml | 57 ++++ src/engine.rs | 22 +- src/peripherals.rs | 656 ++++++++++++++++++++++++---------------- 4 files changed, 468 insertions(+), 275 deletions(-) create mode 100644 config/mempool_ita.yaml diff --git a/config/mempool.yaml b/config/mempool.yaml index 58627d5..b9e776d 100644 --- a/config/mempool.yaml +++ b/config/mempool.yaml @@ -37,13 +37,9 @@ memory: latency: 5 callbacks: - name: zero-memory - size: 0x40 - - name: mempool-ita - size: 32 - - name: zero-memory - size: 0xFFA0 + size: 0x10000 - name: mempool-dma - size: 28 + size: 0x1C inst_latency: mul: 3 mulh: 3 diff --git a/config/mempool_ita.yaml b/config/mempool_ita.yaml new file mode 100644 index 0000000..3c26e30 --- /dev/null +++ b/config/mempool_ita.yaml @@ -0,0 +1,57 @@ +# Copyright 2021 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 + +--- +address: + scratch_reg: 0x40000000 + wakeup_reg: 0x40000004 + tcdm_start: 0x40000008 + tcdm_end: 0x4000000C + nr_cores: 0x40000010 + uart: 0xC0000000 + # Not supported in MemPool + barrier_reg: + start: 0x50000000 + offset: 0x100000 + cluster_base_hartid: 0x50000001 + cluster_num: 0x50000002 + cluster_id: 0x50000003 + cl_clint: 0x40000060 + clint: 0xFFFF0000 +memory: + tcdm: + start: 0x0 + size: 0x100000 + offset: 0x100000 + latency: 5 + dram: + start: 0x80000000 + size: 0x01000000 + offset: 0x0 + latency: 10 + periphs: + start: 0x40000000 + size: 0x20000 + offset: 0x0 + latency: 5 + callbacks: + - name: zero-memory + size: 0x40 + - name: mempool-ita + size: 0xC0 + - name: zero-memory + size: 0xFF00 + - name: mempool-dma + size: 0x1C +inst_latency: + mul: 3 + mulh: 3 + mulhsu: 3 + mulhu: 3 + div: 3 + divu: 3 + rem: 3 + remu: 3 +ssr: + num_dm: 3 diff --git a/src/engine.rs b/src/engine.rs index 08b7f19..216996b 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -380,8 +380,8 @@ impl Engine { + self.config.memory.tcdm.offset * i as u32 + self.config.memory.tcdm.size) { - debug!("Entering TCDM allocation section."); - debug!( + trace!("Entering TCDM allocation section."); + trace!( "Writing value into position: 0x{:x}", ((addr - ((self.config.memory.tcdm.start @@ -756,8 +756,8 @@ impl<'a, 'b> Cpu<'a, 'b> { + self.engine.config.memory.tcdm.size) }) => { - debug!("TCDM Binary Load"); - debug!("Binary load address: 0x{:x}", x); + trace!("TCDM Binary Load"); + trace!("Binary load address: 0x{:x}", x); let id = (0..self.engine.num_clusters) .position(|i| { addr >= (self.engine.config.memory.tcdm.start @@ -786,8 +786,8 @@ impl<'a, 'b> Cpu<'a, 'b> { + self.engine.config.memory.periphs.size) }) => { - debug!("Peripheral Binary Load"); - debug!("Binary load address: 0x{:x}", x); + trace!("Peripheral Binary Load"); + trace!("Binary load address: 0x{:x}", x); let id = (0..self.engine.num_clusters) .position(|i| { addr >= (self.engine.config.memory.periphs.start @@ -898,8 +898,8 @@ impl<'a, 'b> Cpu<'a, 'b> { + self.engine.config.memory.tcdm.size) }) => { - debug!("TCDM Binary Store"); - debug!("Binary store address: 0x{:x}", x); + trace!("TCDM Binary Store"); + trace!("Binary store address: 0x{:x}", x); let id = (0..self.engine.num_clusters) .position(|i| { addr >= (self.engine.config.memory.tcdm.start @@ -933,8 +933,8 @@ impl<'a, 'b> Cpu<'a, 'b> { + self.engine.config.memory.periphs.size) }) => { - debug!("Peripheral Binary store"); - debug!("Binary store address: 0x{:x}", x); + trace!("Peripheral Binary store"); + trace!("Binary store address: 0x{:x}", x); let id = (0..self.engine.num_clusters) .position(|i| { addr >= (self.engine.config.memory.periphs.start @@ -1059,7 +1059,6 @@ impl<'a, 'b> Cpu<'a, 'b> { // n in bytes trace!("MEMCPY From {:08x} to {:08x} num: {:08x}", src, dest, n); if dest % 4 == 0 && src % 4 == 0 && n % 4 == 0 { - warn!("MEMCPY aligned"); // Aligned transfer for _ in 0..n / 4 { let tmp = self.binary_load(src, 2); @@ -1068,7 +1067,6 @@ impl<'a, 'b> Cpu<'a, 'b> { dest += 4; } } else { - warn!("MEMCPY unaligned"); for _ in 0..n { let tmp = self.binary_load(src, 0); self.binary_store(dest, tmp, (u8::MAX as u32) << (8 * (dest % 4)), 0); diff --git a/src/peripherals.rs b/src/peripherals.rs index 2faf98f..405bd42 100644 --- a/src/peripherals.rs +++ b/src/peripherals.rs @@ -6,7 +6,7 @@ use crate::configuration::Callback; use crate::Cpu; use ndarray::{s, Array1, Array2, Array3}; -use std::sync::atomic::{AtomicI32, AtomicU32, Ordering}; +use std::sync::atomic::{AtomicU32, Ordering}; use PeriphReq::{Load, Store}; /// Reference held by execution engine, referencing each peripheral instance in each cluster @@ -306,99 +306,162 @@ impl Peripheral for MemPoolDMA { #[derive(Default)] struct MemPoolITA { - config: AtomicU32, - start_address: AtomicU32, - eps_mul_0: AtomicU32, - eps_mul_1: AtomicU32, - right_shift_0: AtomicU32, - right_shift_1: AtomicU32, - add_0: AtomicI32, - add_1: AtomicI32, + state: [AtomicU32; 4], + start_addr: [AtomicU32; 4], + out_addr: [AtomicU32; 4], + rqs_addr: [AtomicU32; 4], + seq_len: [AtomicU32; 4], + emb_len: [AtomicU32; 4], + proj_len: [AtomicU32; 4], } - impl Peripheral for MemPoolITA { /// should return the same name as in the config file fn get_name(&self) -> &'static str { "mempool-ita" } + /// store instruction fn store(&self, cpu: &Cpu, addr: u32, value: u32, _mask: u32, _size: u8) { + let i = addr as usize / 0x30; + let addr = addr as usize % 0x30; + match addr { 0x00 => unsafe { - self.config.store(value as u32, Ordering::SeqCst); + self.state[i].store(value as u32, Ordering::SeqCst); + info!( + "[ITA {}, CPU {}] Store state 0x{:02x}", + i, &cpu.hartid, value + ); // Out addresses are currently hardcoded in ITA - let out_addresses: [u32; 4] = [0x000C3000, 0x000D3000, 0x000E3000, 0x000F3000]; - let head_config = std::mem::transmute::(value); - let mut return_value = 0; - debug!("[ITA] Store config {:x}", value); - for (i, c) in head_config.iter().enumerate() { - if *c & 0x1 == 1 { - // Start ITA - self.run_ita( - cpu, - self.start_address.load(Ordering::SeqCst), - out_addresses[i], - self.eps_mul_0.load(Ordering::SeqCst), - self.eps_mul_1.load(Ordering::SeqCst), - self.right_shift_0.load(Ordering::SeqCst), - self.right_shift_1.load(Ordering::SeqCst), - self.add_0.load(Ordering::SeqCst), - self.add_1.load(Ordering::SeqCst), - ); - // Set `config` to done - return_value |= 0x1a << (8 * i); - } + let mut return_value = value; + if value & 0x1 == 1 { + // Start ITA + self.run_ita( + cpu, + self.start_addr[i].load(Ordering::SeqCst), + // All ITA cores fetch the Q and K vector always from the address specified to core 0 + self.start_addr[0].load(Ordering::SeqCst), + self.out_addr[i].load(Ordering::SeqCst), + self.rqs_addr[i].load(Ordering::SeqCst), + self.seq_len[i].load(Ordering::SeqCst), + self.emb_len[i].load(Ordering::SeqCst), + self.proj_len[i].load(Ordering::SeqCst), + 16, + ); + // Set busy flag + return_value |= 0x2; + // Clear start flag + return_value &= !0x1; + + self.state[i].store(return_value, Ordering::SeqCst); + info!("[ITA {}, CPU {}] Done.", i, &cpu.hartid); + info!( + "[ITA {}, CPU {}] Store state 0x{:02x}", + i, &cpu.hartid, return_value + ); } - self.config.store(return_value, Ordering::SeqCst); - debug!("[ITA] Save config {:x}", return_value); - }, - 0x04 => self.start_address.store(value as u32, Ordering::SeqCst), - 0x08 => self.eps_mul_0.store(value, Ordering::SeqCst), - 0x0C => self.eps_mul_1.store(value, Ordering::SeqCst), - 0x10 => self.right_shift_0.store(value, Ordering::SeqCst), - 0x14 => self.right_shift_1.store(value, Ordering::SeqCst), - 0x18 => unsafe { - self.add_0 - .store(std::mem::transmute::(value), Ordering::SeqCst) - }, - 0x1C => unsafe { - self.add_1 - .store(std::mem::transmute::(value), Ordering::SeqCst) }, + 0x04 => { + self.start_addr[i].store(value as u32, Ordering::SeqCst); + info!( + "[ITA {}, CPU {}] Store start address 0x{:08x}", + i, &cpu.hartid, value + ) + } + 0x08 => { + self.out_addr[i].store(value as u32, Ordering::SeqCst); + info!( + "[ITA {}, CPU {}] Store out address 0x{:08x}", + i, &cpu.hartid, value + ) + } + 0x0C => { + self.rqs_addr[i].store(value as u32, Ordering::SeqCst); + info!( + "[ITA {}, CPU {}] Store rqs_addr 0x{:016x}", + i, &cpu.hartid, value + ) + } + 0x10 => { + self.seq_len[i].store(value, Ordering::SeqCst); + info!( + "[ITA {}, CPU {}] Store seq_len 0x{:016x}", + i, &cpu.hartid, value + ) + } + 0x14 => { + self.emb_len[i].store(value, Ordering::SeqCst); + info!( + "[ITA {}, CPU {}] Store emb_len 0x{:016x}", + i, &cpu.hartid, value + ) + } + 0x18 => { + self.proj_len[i].store(value, Ordering::SeqCst); + info!( + "[ITA {}, CPU {}] Store proj_len 0x{:016x}", + i, &cpu.hartid, value + ) + } _ => unimplemented!(), } } /// load instruction - fn load(&self, _cpu: &Cpu, addr: u32, _size: u8) -> u32 { + fn load(&self, cpu: &Cpu, addr: u32, _size: u8) -> u32 { + let i = addr as usize / 0x30; + let addr = addr as usize % 0x30; + match addr { 0x00 => { - let conf = self.config.load(Ordering::SeqCst); - if conf == 0x1a1a1a1a { - self.config.store(0x04040404, Ordering::SeqCst); + let state = self.state[i].load(Ordering::SeqCst) & 0xFF; + info!( + "[ITA {}, CPU {}] Read state 0x{:02x}", + i, &cpu.hartid, state + ); + + let busy_flag = state & 0x02; + + // WIESEP: As we have no timing model, just set the done flag if the busy flag is set + if busy_flag == 0x02 { + let mut new_state = state; + // Clear the busy flag + new_state &= !0x02; + + // Set the done flag + new_state |= 0x4; + self.state[i].store(new_state, Ordering::SeqCst); + info!( + "[ITA {}, CPU {}] > ITA is done -> store state 0x{:02x}", + i, &cpu.hartid, new_state + ); } - conf + state } - 0x04 => self.start_address.load(Ordering::SeqCst), - 0x08 => self.eps_mul_0.load(Ordering::SeqCst), - 0x0C => self.eps_mul_1.load(Ordering::SeqCst), - 0x10 => self.right_shift_0.load(Ordering::SeqCst), - 0x14 => self.right_shift_1.load(Ordering::SeqCst), - 0x18 => unsafe { std::mem::transmute::(self.add_0.load(Ordering::SeqCst)) }, - 0x1C => unsafe { std::mem::transmute::(self.add_1.load(Ordering::SeqCst)) }, + 0x04 => self.start_addr[i].load(Ordering::SeqCst), + 0x08 => self.out_addr[i].load(Ordering::SeqCst), + 0x0C => self.rqs_addr[i].load(Ordering::SeqCst), + 0x10 => self.seq_len[i].load(Ordering::SeqCst), + 0x14 => self.emb_len[i].load(Ordering::SeqCst), + 0x18 => self.proj_len[i].load(Ordering::SeqCst), _ => unimplemented!(), } } } impl MemPoolITA { - fn transpose_3d(data: &mut Array3, m: u32, n: u32, p: u32) { - let copy = data.clone(); + fn transpose_2d_arrays(array: &mut Array3) -> Array3 + where + T: Clone, + { + return array.to_owned().permuted_axes([0, 2, 1]); + } + + unsafe fn ita_load_2d_i32(cpu: &Cpu, data: &mut Array2, mut address: u32, m: u32, n: u32) { for j in 0..m { for i in 0..n { - for h in 0..p { - data[[j as usize, i as usize, h as usize]] = - copy[[j as usize, h as usize, i as usize]]; - } + let word = cpu.binary_load(address, 2); + data[[j as usize, i as usize]] = word as i32; + address += 4; } } } @@ -471,18 +534,14 @@ impl MemPoolITA { elements[offset] = data[[j as usize, ((n / splits) * split + i) as usize + offset]] as u8; } - // let word = std::mem::transmute::<[u8; 4], u32>(elements); let word = u32::from_ne_bytes(elements); cpu.binary_store(address + address_offset, word, u32::MAX, 2); - for y in 0..4 { - let _stest = cpu.binary_load(address + address_offset + y, 2); - } - trace!("[ITA] Store OUT to 0x{:x}", address + address_offset); + debug!( + "[ITA, CPU {}] Store OUT to 0x{:x}", + &cpu.hartid, + address + address_offset + ); address_offset += 4; - // if address_offset % 0x100 == 0 { - // address_offset -= 0x0100; - // address_offset += 0x1000; - // } } } } @@ -492,95 +551,185 @@ impl MemPoolITA { &self, cpu: &Cpu, start_address: u32, + start_address_core0: u32, out_address: u32, - _eps_mult_0: u32, - _eps_mult_1: u32, - _right_shift_0: u32, - _right_shift_1: u32, - _add_0: i32, - _add_1: i32, + rqs_address: u32, + seq_len: u32, + emb_len: u32, + proj_len: u32, + processing_engines: u32, ) { - // TODO `eps_mult` and `right_shift` are currently hardcoded // Setup of matrices for query_projection_space_transformation and key_projection_space_transformation // Sequence of addresses are hardcoded let start = start_address; - let offset = 64 * 64; - let w4_addr = start + offset * 0; - let w3_addr = start + offset * 1; - let w2_addr = start + offset * 2; - let q_addr = start + offset * 3; - let k_addr = start + offset * 4; - let w1_addr = start + offset * 5; - let b4_addr = start + offset * 6; - let b3_addr = start + offset * 7; - let b2_addr = start + offset * 8; - let b1_addr = start + offset * 9; - - let rqs_mult = u64::to_le_bytes((_eps_mult_1 as u64) << 32 | (_eps_mult_0 as u64)); - let rqs_shift = u64::to_le_bytes((_right_shift_1 as u64) << 32 | (_right_shift_0 as u64)); - let rqs_add = u64::to_le_bytes((_add_1 as u64) << 32 | (_add_0 as u64)).map(|c| c as i8); + let w_o_addr = start; + let w_v_addr = start + proj_len * emb_len; + let w_k_addr = start + proj_len * emb_len * 2; + let w_q_addr = start + proj_len * emb_len * 3 + seq_len * emb_len * 2; + let b_o_addr = start + proj_len * emb_len * 4 + seq_len * emb_len * 2; + let b_v_addr = start + proj_len * emb_len * 4 + seq_len * emb_len * 2 + emb_len * 4; // 32 bit biases + let b_k_addr = + start + proj_len * emb_len * 4 + seq_len * emb_len * 2 + emb_len * 4 + proj_len * 4; // 32 bit biases + let b_q_addr = + start + proj_len * emb_len * 4 + seq_len * emb_len * 2 + emb_len * 4 + proj_len * 8; // 32 bit biases + + let q_addr = start_address_core0 + proj_len * emb_len * 3; + let k_addr = start_address_core0 + proj_len * emb_len * 3 + seq_len * emb_len; + + let mult_address = cpu.binary_load(rqs_address + 0x00, 2); + let shift_address = cpu.binary_load(rqs_address + 0x04, 2); + let add_address = cpu.binary_load(rqs_address + 0x08, 2); + + let rqs_mult_w1 = u32::to_ne_bytes(cpu.binary_load(mult_address + 0x00, 2)); + let rqs_mult_w2 = u32::to_ne_bytes(cpu.binary_load(mult_address + 0x04, 2)); + + let rqs_mult: [u8; 6] = [ + rqs_mult_w1[0], + rqs_mult_w1[1], + rqs_mult_w1[2], + rqs_mult_w1[3], + rqs_mult_w2[0], + rqs_mult_w2[1], + ]; + + let rqs_shift_w1 = u32::to_ne_bytes(cpu.binary_load(shift_address + 0x00, 2)); + let rqs_shift_w2: [u8; 4] = u32::to_ne_bytes(cpu.binary_load(shift_address + 0x04, 2)); + + let rqs_shift: [u8; 6] = [ + rqs_shift_w1[0], + rqs_shift_w1[1], + rqs_shift_w1[2], + rqs_shift_w1[3], + rqs_shift_w2[0], + rqs_shift_w2[1], + ]; + + let rqs_add: [i32; 6] = [ + cpu.binary_load(add_address + 0x00, 2) as i32, + cpu.binary_load(add_address + 0x04, 2) as i32, + cpu.binary_load(add_address + 0x08, 2) as i32, + cpu.binary_load(add_address + 0x0C, 2) as i32, + cpu.binary_load(add_address + 0x10, 2) as i32, + cpu.binary_load(add_address + 0x14, 2) as i32, + ]; + + debug!("[ITA, CPU {}] w_o_addr 0x{:x}", &cpu.hartid, w_o_addr); + debug!("[ITA, CPU {}] w_v_addr 0x{:x}", &cpu.hartid, w_v_addr); + debug!("[ITA, CPU {}] w_k_addr 0x{:x}", &cpu.hartid, w_k_addr); + debug!("[ITA, CPU {}] q_addr 0x{:x}", &cpu.hartid, q_addr); + debug!("[ITA, CPU {}] k_addr 0x{:x}", &cpu.hartid, k_addr); + debug!("[ITA, CPU {}] w_q_addr 0x{:x}", &cpu.hartid, w_q_addr); + debug!("[ITA, CPU {}] b_o_addr 0x{:x}", &cpu.hartid, b_o_addr); + debug!("[ITA, CPU {}] b_v_addr 0x{:x}", &cpu.hartid, b_v_addr); + debug!("[ITA, CPU {}] b_k_addr 0x{:x}", &cpu.hartid, b_k_addr); + debug!("[ITA, CPU {}] b_q_addr 0x{:x}", &cpu.hartid, b_q_addr); debug!( - "[ITA] Start Address 0x{:x}, Out Address 0x{:x}", - start, out_address + "[ITA, CPU {}] mult_address 0x{:x}", + &cpu.hartid, mult_address + ); + debug!( + "[ITA, CPU {}] shift_address 0x{:x}", + &cpu.hartid, shift_address + ); + debug!( + "[ITA, CPU {}] add_address 0x{:x}", + &cpu.hartid, add_address + ); + + let split_e = emb_len / processing_engines; + let split_p = proj_len / processing_engines; + + debug!( + "[ITA, CPU {}] Start Address 0x{:x}, Out Address 0x{:x}", + &cpu.hartid, start, out_address + ); + debug!("[ITA, CPU {}] RQS Mult {:?}", &cpu.hartid, rqs_mult); + debug!("[ITA, CPU {}] RQS Shift {:?}", &cpu.hartid, rqs_shift); + debug!("[ITA, CPU {}] RQS Add {:?}", &cpu.hartid, rqs_add); + debug!( + "[ITA, CPU {}] S {:?}, E {:?}, P {:?}", + &cpu.hartid, seq_len, emb_len, proj_len ); - debug!("[ITA] RQS Mult {:?}", rqs_mult); - debug!("[ITA] RQS Shift {:?}", rqs_shift); - debug!("[ITA] RQS Add {:?}", rqs_add); + debug!( + "[ITA, CPU {}] Split E {:?}, Split P {:?}", + &cpu.hartid, split_e, split_p + ); + + let mut q = Array2::::zeros((seq_len as usize, emb_len as usize)); + MemPoolITA::ita_load_2d(cpu, &mut q, q_addr, seq_len, emb_len, split_e); + debug!("[ITA, CPU {}] q.shape: {:?}", &cpu.hartid, q.shape()); + debug!("[ITA, CPU {}] q: {}", &cpu.hartid, q); - let mut q = Array2::::zeros((64, 64)); - MemPoolITA::ita_load_2d(cpu, &mut q, q_addr, 64, 64, 4); - let mut w_q = Array3::::zeros((1, 64, 64)); - MemPoolITA::ita_load_3d(cpu, &mut w_q, w1_addr, 1, 64, 64, 4); - MemPoolITA::transpose_3d(&mut w_q, 1, 64, 64); + let mut w_q = Array3::::zeros((1, proj_len as usize, emb_len as usize)); + MemPoolITA::ita_load_3d(cpu, &mut w_q, w_q_addr, 1, proj_len, emb_len, split_e); + w_q = MemPoolITA::transpose_2d_arrays(&mut w_q); + debug!("[ITA, CPU {}] w_q.shape: {:?}", &cpu.hartid, w_q.shape()); + debug!("[ITA, CPU {}] w_q: {}", &cpu.hartid, w_q); - let mut k = Array2::::zeros((64, 64)); - MemPoolITA::ita_load_2d(cpu, &mut k, k_addr, 64, 64, 4); + let mut k = Array2::::zeros((seq_len as usize, emb_len as usize)); + MemPoolITA::ita_load_2d(cpu, &mut k, k_addr, seq_len, emb_len, split_e); + debug!("[ITA, CPU {}] k.shape: {:?}", &cpu.hartid, k.shape()); + debug!("[ITA, CPU {}] k: {}", &cpu.hartid, k); - let mut w_k = Array3::::zeros((1, 64, 64)); - MemPoolITA::ita_load_3d(cpu, &mut w_k, w2_addr, 1, 64, 64, 1); - MemPoolITA::transpose_3d(&mut w_k, 1, 64, 64); + let mut w_k = Array3::::zeros((1, proj_len as usize, emb_len as usize)); + MemPoolITA::ita_load_3d(cpu, &mut w_k, w_k_addr, 1, proj_len, emb_len, 1); + w_k = MemPoolITA::transpose_2d_arrays(&mut w_k); + debug!("[ITA, CPU {}] w_k.shape: {:?}", &cpu.hartid, w_k.shape()); + debug!("[ITA, CPU {}] w_k: {}", &cpu.hartid, w_k); // Setup of matrices for value_projection_space_transformation - let mut b_v = Array3::::zeros((1, 64, 64)); - MemPoolITA::ita_load_3d(cpu, &mut b_v, b3_addr, 1, 64, 64, 4); - MemPoolITA::transpose_3d(&mut b_v, 1, 64, 64); + let mut b_v = Array2::::zeros((1, proj_len as usize)); + MemPoolITA::ita_load_2d_i32(cpu, &mut b_v, b_v_addr, 1, proj_len); + debug!("[ITA, CPU {}] b_v.shape: {:?}", &cpu.hartid, b_v.shape()); + debug!("[ITA, CPU {}] b_v: {}", &cpu.hartid, b_v); let mut v = k.clone(); - let mut w_v = Array3::::zeros((1, 64, 64)); - MemPoolITA::ita_load_3d(cpu, &mut w_v, w3_addr, 1, 64, 64, 1); - MemPoolITA::transpose_3d(&mut w_v, 1, 64, 64); + let mut w_v = Array3::::zeros((1, proj_len as usize, emb_len as usize)); + MemPoolITA::ita_load_3d(cpu, &mut w_v, w_v_addr, 1, proj_len, emb_len, 1); + w_v = MemPoolITA::transpose_2d_arrays(&mut w_v); + debug!("[ITA, CPU {}] w_v.shape: {:?}", &cpu.hartid, w_v.shape()); + debug!("[ITA, CPU {}] w_v: {}", &cpu.hartid, w_v); - let mut v_p = Array3::::zeros((1, 64, 64)); + let mut v_p = Array3::::zeros((1, seq_len as usize, proj_len as usize)); // matrices in the query_projection_space_transformation - let mut b_q = Array3::::zeros((1, 64, 64)); - MemPoolITA::ita_load_3d(cpu, &mut b_q, b1_addr, 1, 64, 64, 4); - let mut q_p = Array3::::zeros((1, 64, 64)); + let mut b_q = Array2::::zeros((1, proj_len as usize)); + MemPoolITA::ita_load_2d_i32(cpu, &mut b_q, b_q_addr, 1, proj_len); + debug!("[ITA, CPU {}] b_q.shape: {:?}", &cpu.hartid, b_q.shape()); + debug!("[ITA, CPU {}] b_q: {}", &cpu.hartid, b_q); + let mut q_p = Array3::::zeros((1, seq_len as usize, proj_len as usize)); // matrices in the key_projection_space_transformation - let mut b_k = Array3::::zeros((1, 64, 64)); - MemPoolITA::ita_load_3d(cpu, &mut b_k, b2_addr, 1, 64, 64, 4); + let mut b_k = Array2::::zeros((1, proj_len as usize)); + MemPoolITA::ita_load_2d_i32(cpu, &mut b_k, b_k_addr, 1, proj_len); + debug!("[ITA, CPU {}] b_k.shape: {:?}", &cpu.hartid, b_k.shape()); + debug!("[ITA, CPU {}] b_k: {}", &cpu.hartid, b_k); - let mut k_p = Array3::::zeros((1, 64, 64)); + let mut k_p = Array3::::zeros((1, seq_len as usize, proj_len as usize)); // matrices in the streaming_partial_softmax - let mut a_requant = Array3::::zeros((1, 64, 64)); - let mut a_partial_softmax = Array2::::zeros((64, 64)); + let mut a_requant = Array3::::zeros((1, seq_len as usize, seq_len as usize)); + let mut a_partial_softmax = Array2::::zeros((seq_len as usize, seq_len as usize)); // matrices in multi_head_computation - let mut out = Array3::::zeros((1, 64, 64)); - let mut b_o = Array3::::zeros((1, 64, 64)); - MemPoolITA::ita_load_3d(cpu, &mut b_o, b4_addr, 1, 64, 64, 4); - let mut w_o = Array3::::zeros((1, 64, 64)); - MemPoolITA::ita_load_3d(cpu, &mut w_o, w4_addr, 1, 64, 64, 1); - MemPoolITA::transpose_3d(&mut w_o, 1, 64, 64); + let mut out = Array3::::zeros((1, seq_len as usize, emb_len as usize)); + let mut b_o = Array2::::zeros((1, emb_len as usize)); + MemPoolITA::ita_load_2d_i32(cpu, &mut b_o, b_o_addr, 1, emb_len); + + debug!("[ITA, CPU {}] b_o.shape: {:?}", &cpu.hartid, b_o.shape()); + debug!("[ITA, CPU {}] b_o: {}", &cpu.hartid, b_o); + + let mut w_o = Array3::::zeros((1, emb_len as usize, proj_len as usize)); + MemPoolITA::ita_load_3d(cpu, &mut w_o, w_o_addr, 1, emb_len, proj_len, 1); + w_o = MemPoolITA::transpose_2d_arrays(&mut w_o); + debug!("[ITA, CPU {}] w_o.shape: {:?}", &cpu.hartid, w_o.shape()); + debug!("[ITA, CPU {}] w_o: {}", &cpu.hartid, w_o); // query_projection_space_transformation - // query_projection_space_transformation(&mut q_p, &mut q, &mut w_q, &mut b_q, 1); MemPoolITA::projection_space_transformation(&mut q_p, &mut q, &mut w_q, &mut b_q, 1); // requantization of q_p - let mut q_p_requant = Array3::::zeros((1, 64, 64)); + let mut q_p_requant = Array3::::zeros((1, seq_len as usize, proj_len as usize)); MemPoolITA::requantization_3d( &mut q_p, &mut q_p_requant, @@ -588,16 +737,12 @@ impl MemPoolITA { rqs_shift[0], rqs_add[0], ); - // debug!("q_p_requant: {}", q_p_requant); + debug!("[ITA, CPU {}] q_p_requant: {}", &cpu.hartid, q_p_requant); // key_projection_space_transformation - // key_projection_space_transformation(&mut k_p, &mut k, &mut w_k, &mut b_k, 1); - // debug!("k: {}", k); - // debug!("w_k: {}", w_k); - // debug!("b_k: {}", b_k); MemPoolITA::projection_space_transformation(&mut k_p, &mut k, &mut w_k, &mut b_k, 1); // requantization of k_p - let mut k_p_requant = Array3::::zeros((1, 64, 64)); + let mut k_p_requant = Array3::::zeros((1, seq_len as usize, proj_len as usize)); MemPoolITA::requantization_3d( &mut k_p, &mut k_p_requant, @@ -605,10 +750,10 @@ impl MemPoolITA { rqs_shift[1], rqs_add[1], ); - // debug!("k_p_requant: {}", k_p_requant); + debug!("[ITA, CPU {}] k_p_requant: {}", &cpu.hartid, k_p_requant); // query_key_correlation - let mut qk = Array3::::zeros((1, 64, 64)); + let mut qk = Array3::::zeros((1, seq_len as usize, seq_len as usize)); MemPoolITA::query_key_correlation(&mut q_p_requant, &mut k_p_requant, &mut qk); // requantization of qk MemPoolITA::requantization_3d( @@ -618,16 +763,20 @@ impl MemPoolITA { rqs_shift[2], rqs_add[2], ); - // debug!("a_requant: {}", a_requant); + debug!("[ITA, CPU {}] a_requant: {}", &cpu.hartid, a_requant); // streaming_partial_softmax - MemPoolITA::streaming_partial_softmax(&mut a_requant, &mut a_partial_softmax, 64); + MemPoolITA::streaming_partial_softmax( + &mut a_requant, + &mut a_partial_softmax, + seq_len, + processing_engines, + ); // value_projection_space_transformation - // value_projection_space_transformation(&mut v_p, &mut v, &mut w_v, &mut b_v, 1); MemPoolITA::projection_space_transformation(&mut v_p, &mut v, &mut w_v, &mut b_v, 1); // requantization of v_p - let mut v_p_requant = Array3::::zeros((1, 64, 64)); + let mut v_p_requant = Array3::::zeros((1, seq_len as usize, proj_len as usize)); MemPoolITA::requantization_3d( &mut v_p, &mut v_p_requant, @@ -635,17 +784,17 @@ impl MemPoolITA { rqs_shift[3], rqs_add[3], ); - // debug!("v_p_requant: {}", v_p_requant); + debug!("[ITA, CPU {}] v_p_requant: {}", &cpu.hartid, v_p_requant); // single_head_computation - let mut o_softmax = Array3::::zeros((1, 64, 64)); + let mut o_softmax = Array3::::zeros((1, seq_len as usize, proj_len as usize)); MemPoolITA::single_head_computation( &mut a_partial_softmax, &mut v_p_requant, &mut o_softmax, ); // requantization of o_softmax - let mut o_softmax_requant = Array3::::zeros((1, 64, 64)); + let mut o_softmax_requant = Array3::::zeros((1, seq_len as usize, proj_len as usize)); MemPoolITA::requantization_3d( &mut o_softmax, &mut o_softmax_requant, @@ -653,12 +802,15 @@ impl MemPoolITA { rqs_shift[4], rqs_add[4], ); - // debug!("o_softmax_requant: {}", o_softmax_requant); + debug!( + "[ITA, CPU {}] o_softmax_requant: {}", + &cpu.hartid, o_softmax_requant + ); // multi_head_computation MemPoolITA::multi_head_computation(&mut o_softmax_requant, &mut out, &mut w_o, &mut b_o, 1); // parallel requantization of out - let mut out_requant = Array2::::zeros((64, 64)); + let mut out_requant = Array2::::zeros((seq_len as usize, emb_len as usize)); MemPoolITA::parallel_requantize3d( &mut out, &mut out_requant, @@ -666,18 +818,13 @@ impl MemPoolITA { rqs_shift[5], rqs_add[5], ); - // debug!("out_requant: {}", out_requant); - - // for j in 0..out_requant.shape()[1] { - // let row = out_requant.slice(s![j, ..]); - // debug!("out[{},:]:\n{}", j, row); - // } + debug!("[ITA, CPU {}] out_requant: {}", &cpu.hartid, out_requant); // Store the output - MemPoolITA::ita_store_2d(cpu, &out_requant, out_address, 64, 64, 1); + MemPoolITA::ita_store_2d(cpu, &out_requant, out_address, seq_len, emb_len, 1); } - fn requantize_row(element: i32, eps_mult: u8, right_shift: u8, add: i8) -> i8 { + fn requantize_row(element: i32, eps_mult: u8, right_shift: u8, add: i32) -> i8 { let mut shifted = ((element * (eps_mult as i32)) >> (right_shift as i32)) + (add as i32); // Perform rounding half away from zero @@ -700,10 +847,8 @@ impl MemPoolITA { m_requant: &mut Array3, eps_mult: u8, right_shift: u8, - add: i8, + add: i32, ) { - // debug!("===================== 3D Requantization ====================="); - // Loop over the number of heads for i in 0..m.shape()[0] { // Loop over the head dimension @@ -724,10 +869,9 @@ impl MemPoolITA { m_requant: &mut Array2, eps_mult: u8, right_shift: u8, - add: i8, + add: i32, ) { - // debug!("===================== Parallel 3D Requantization ====================="); - m_requant.fill(add); + m_requant.fill(add as i8); for i in 0..m.shape()[0] { for j in 0..m.shape()[1] { let row = m.slice(s![i, j, ..]); @@ -752,35 +896,35 @@ impl MemPoolITA { p: &mut Array3, m: &mut Array2, w: &mut Array3, - b: &mut Array3, + b: &mut Array2, bias: u8, ) { - // debug!("===================== Projection Space Transformation ====================="); - if bias == 1 { - for i in 0..p.shape()[0] { - for j in 0..p.shape()[1] { - for k in 0..p.shape()[2] { - p[[i, j, k]] = b[[i, j, k]] as i32; - for l in 0..m.shape()[1] { - p[[i, j, k]] += m[[j, l]] as i32 * w[[i, l, k]] as i32; - } - } - } - } - } else { - for i in 0..p.shape()[0] { - for j in 0..p.shape()[1] { - for k in 0..p.shape()[2] { - p[[i, j, k]] = 0; - for l in 0..m.shape()[1] { - p[[i, j, k]] += m[[j, l]] as i32 * w[[i, l, k]] as i32; - } - } - } + info!("===================== Projection Space Transformation ====================="); + info!("p shape: {:?}", p.shape()); + info!("m shape: {:?}", m.shape()); + info!("w: {:?}", w.shape()); + info!("b: {:?}", b.shape()); + + // Calculate p[h] = m * W[h] + b[h] for each head h + + let d1 = m.shape(); + let d2 = w.shape(); + + assert_eq!(d1[1], d2[1], "Matrices dimensions don't match"); + + for i in 0..p.shape()[0] { + let slice_a = m.map(|x| *x as i32); + let slice_b = w.slice(s![i, .., ..]).map(|x| *x as i32); + let slice_c = b.slice(s![i, ..]).map(|x| *x); + let slice_c = slice_c.broadcast((d1[0], d2[2])).unwrap().map(|x| *x); + let mut mult_a_b = slice_a.dot(&slice_b); + + if bias == 1 { + mult_a_b = mult_a_b + slice_c; } - } - // debug!("projected matrix: {:?}", p); + p.slice_mut(s![i, .., ..]).assign(&mult_a_b); + } } fn query_key_correlation( @@ -788,63 +932,67 @@ impl MemPoolITA { kp_requant: &mut Array3, qk: &mut Array3, ) { - // debug!("===================== Query Key Correlation ====================="); + info!("===================== Query Key Correlation ====================="); + info!("qp_requant shape: {:?}", qp_requant.shape()); + info!("kp_requant shape: {:?}", kp_requant.shape()); + info!("qk shape: {:?}", qk.shape()); + + let d1 = qp_requant.shape(); + let d2 = kp_requant.shape(); + + assert_eq!(d1[2], d2[2], "Matrices dimensions don't match"); + + // Calculate qk[h] = qp_requant[h] * kp_requant[h].T for each head h + let kp_requant_transposed = MemPoolITA::transpose_2d_arrays(kp_requant); - // Loop over the number of heads for i in 0..qk.shape()[0] { - // Loop over the number of queries - for j in 0..qk.shape()[1] { - // Loop over the number of keys - for k in 0..qk.shape()[2] { - qk[[i, j, k]] = 0; - // Loop over the number of features - for l in 0..qk.shape()[1] { - qk[[i, j, k]] += qp_requant[[i, j, l as usize]] as i32 - * kp_requant[[i, k, l as usize]] as i32; - } - } - } - } + let slice_a = qp_requant.slice(s![i, .., ..]).map(|x| *x as i32); + let slice_b = kp_requant_transposed + .slice(s![i, .., ..]) + .map(|x| *x as i32); + let mult_a_b = slice_a.dot(&slice_b); - // debug!("qk: {:?}", qk); + qk.slice_mut(s![i, .., ..]).assign(&mult_a_b); + } } //Compute the approximated softmax function. fn streaming_partial_softmax( a_requant: &mut Array3, a_partial_softmax: &mut Array2, - seq_len: i32, + seq_len: u32, + processing_engines: u32, ) { - // debug!("===================== Streaming Partial SoftMax ====================="); - // let log2e: f64 = f64::log2(f64::exp(1.0)); // let b = 8; // let eps_x = b as f64 / (2.0f64.powi(b) * log2e); let mut exp_partial_sum = Array1::::zeros(seq_len as usize); - let mut max = Array1::::zeros(64); - let mut current_max = Array1::::zeros(64); - - for i in 0..4 { - let a_requant_slice = a_requant.slice_mut(s![0, .., i * 16..(i + 1) * 16]); + let mut max = Array1::::zeros(seq_len as usize); + let mut current_max = Array1::::zeros(seq_len as usize); + let _processing_engines = processing_engines as usize; + let groups = seq_len as usize / _processing_engines; + + for i in 0..groups { + let a_requant_slice = a_requant.slice_mut(s![ + 0, + .., + i * _processing_engines..(i + 1) * _processing_engines + ]); for n in 0..a_requant_slice.nrows() { current_max[[n]] = a_requant_slice.row(n).iter().copied().max().unwrap() as i8; } for j in 0..seq_len { - let mut shift_sum; + let mut shift_sum: u8; if i == 0 || current_max[j as usize] > max[[j as usize]] { if i == 0 { shift_sum = 0; } else { - shift_sum = (current_max[j as usize] - max[[j as usize]]) / 32; - // if (((current_max[j as usize] - max[[j as usize]]) / 32) - shift_sum) as f64 - // >= 0.5 - // { - // shift_sum += 1; - // } let shift_int = (current_max[j as usize] as i32) - (max[[j as usize]] as i32); + shift_sum = (shift_int / 32) as u8; + if shift_int % 32 >= 16 { shift_sum += 1; } @@ -855,7 +1003,11 @@ impl MemPoolITA { } let qb = a_requant - .slice_mut(s![0, .., i * 16..(i + 1) * 16]) + .slice_mut(s![ + 0, + .., + i * _processing_engines..(i + 1) * _processing_engines + ]) .mapv(|x| x as i32 - max[[j as usize]] as i32); let mut qexp = 0; @@ -876,7 +1028,7 @@ impl MemPoolITA { } for j in 0..seq_len { let factor = - ((2.0f64.powi(8) - 1.0) * 2.0f64.powi(10)) as i32 / exp_partial_sum[j as usize]; + ((2.0f64.powi(7) - 1.0) * 2.0f64.powi(10)) as i32 / exp_partial_sum[j as usize]; for k in 0..seq_len { let mut shift = (((max[j as usize] as i32) - (a_requant[[0, j as usize, k as usize]] as i32)) @@ -890,8 +1042,6 @@ impl MemPoolITA { (factor as i32) / 2.0f64.powi(shift) as i32; } } - - // debug!("a_partial_softmax: {}", a_partial_softmax); } fn single_head_computation( @@ -899,8 +1049,6 @@ impl MemPoolITA { vp_requant: &mut Array3, o_softmax: &mut Array3, ) { - // debug!("===================== Single Head Computation ====================="); - // Loop over the number of heads for i in 0..o_softmax.shape()[0] { // Loop over the number of queries @@ -916,45 +1064,39 @@ impl MemPoolITA { } } } - - // debug!("o_softmax: {:?}", o_softmax); } fn multi_head_computation( o_softmax_requant: &mut Array3, out: &mut Array3, w_o: &mut Array3, - b_o: &mut Array3, + b_o: &mut Array2, bias: u8, ) { - // debug!("===================== Multi Head Computation ====================="); - - if bias == 1 { - for i in 0..out.shape()[0] { - for j in 0..out.shape()[1] { - for k in 0..out.shape()[2] { - out[[i, j, k]] = b_o[[i, j, k]] as i32; - for l in 0..out.shape()[1] { - out[[i, j, k]] += - o_softmax_requant[[i, j, l]] as i32 * w_o[[i, l, k]] as i32; - } - } - } + info!("===================== Multi Head Computation ====================="); + info!("o_softmax_requant shape: {:?}", o_softmax_requant.shape()); + info!("out shape: {:?}", out.shape()); + info!("w_o shape: {:?}", w_o.shape()); + info!("b_o shape: {:?}", b_o.shape()); + + let d1 = o_softmax_requant.shape(); + let d2 = w_o.shape(); + + assert_eq!(d1[2], d2[1], "Matrices dimensions don't match"); + + // Calculate out[h] = o_softmax_requant[h] * W_o[h] + b_o[h] for each head h + for i in 0..out.shape()[0] { + let slice_a = o_softmax_requant.slice(s![i, .., ..]).map(|x| *x as i32); + let slice_b = w_o.slice(s![i, .., ..]).map(|x| *x as i32); + let slice_c = b_o.slice(s![i, ..]).map(|x| *x); + let slice_c = slice_c.broadcast((d1[0], d2[2])).unwrap().map(|x| *x); + let mut mult_a_b = slice_a.dot(&slice_b); + + if bias == 1 { + mult_a_b = mult_a_b + slice_c; } - } else { - for i in 0..out.shape()[0] { - for j in 0..out.shape()[1] { - for k in 0..out.shape()[2] { - out[[i, j, k]] = 0; - for l in 0..out.shape()[1] { - out[[i, j, k]] += - o_softmax_requant[[i, j, l]] as i32 * w_o[[i, l, k]] as i32; - } - } - } - } - } - // debug!("out: {:?}", out); + out.slice_mut(s![i, .., ..]).assign(&mult_a_b); + } } }