diff --git a/.github/scripts/build-toolchains.sh b/.github/scripts/build-toolchains.sh deleted file mode 100755 index fa6017ea..00000000 --- a/.github/scripts/build-toolchains.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash - -#------------------------------------------------------------- -# create the riscv tools binaries from ucb-bar/chipyard with rocket-chip hash given by riscv-boom -# -# run location: circle ci docker image -# usage: -# $1 - name of the toolchain to build -#------------------------------------------------------------- - -# turn echo on and error on earliest command -set -ex - -# get shared variables -SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" -source $SCRIPT_DIR/defaults.sh - -INSTALL_DIR="$HOME/$1-install" - -if [ ! -d "$INSTALL_DIR" ]; then - cd $HOME - - git clone --progress --verbose https://github.com/ucb-bar/chipyard.git chipyard - cd $LOCAL_CHIPYARD_DIR - - echo "Checking out Chipyard version: $(cat $LOCAL_CHECKOUT_DIR/CHIPYARD.hash)" - git fetch - git checkout $(cat $LOCAL_CHECKOUT_DIR/CHIPYARD.hash) - - cd $HOME - - # init all submodules including the tools (doesn't use CI_MAKE_PROC due to mem. 
constraints) - CHIPYARD_DIR="$LOCAL_CHIPYARD_DIR" NPROC=$CI_MAKE_NPROC $LOCAL_CHIPYARD_DIR/scripts/build-toolchains.sh esp-tools -fi - diff --git a/.github/scripts/defaults.sh b/.github/scripts/defaults.sh index 6a02a220..e403fc89 100755 --- a/.github/scripts/defaults.sh +++ b/.github/scripts/defaults.sh @@ -28,6 +28,7 @@ LOCAL_ESP_DIR=$HOME/esp-tools-install LOCAL_CHIPYARD_DIR=$HOME/chipyard LOCAL_SIM_DIR=$LOCAL_CHIPYARD_DIR/sims/verilator LOCAL_VERILATOR_DIR=$HOME/verilator-install +LOCAL_CONDA=/opt/conda/ echo "::set-output name=LOCAL_WORK_DIR::$LOCAL_WORK_DIR" echo "::set-output name=LOCAL_CHECKOUT_DIR::$LOCAL_CHECKOUT_DIR" @@ -36,3 +37,4 @@ echo "::set-output name=LOCAL_ESP_DIR::$LOCAL_ESP_DIR" echo "::set-output name=LOCAL_CHIPYARD_DIR::$LOCAL_CHIPYARD_DIR" echo "::set-output name=LOCAL_SIM_DIR::$LOCAL_SIM_DIR" echo "::set-output name=LOCAL_VERILATOR_DIR::$LOCAL_VERILATOR_DIR" +echo "::set-output name=LOCAL_CONDA::$LOCAL_CONDA" diff --git a/.github/scripts/do-rtl-build.sh b/.github/scripts/do-rtl-build.sh index 1b93655e..38651571 100755 --- a/.github/scripts/do-rtl-build.sh +++ b/.github/scripts/do-rtl-build.sh @@ -7,20 +7,12 @@ set -ex SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" source $SCRIPT_DIR/defaults.sh -rm -rf $LOCAL_CHIPYARD_DIR/generators/gemmini/* -cd $LOCAL_CHECKOUT_DIR -git submodule update --init --recursive software/gemmini-rocc-tests -mv -f $LOCAL_CHECKOUT_DIR/* $LOCAL_CHIPYARD_DIR/generators/gemmini/ +source $SCRIPT_DIR/enable-conda.sh +cd $LOCAL_CHIPYARD_DIR +source env.sh -TOOLS_DIR=$LOCAL_ESP_DIR -LD_LIB_DIR=$LOCAL_ESP_DIR/lib - -# enter the verilator directory and build the specific config on remote server +cd $LOCAL_SIM_DIR make -C $LOCAL_SIM_DIR clean -export RISCV=$TOOLS_DIR -export LD_LIBRARY_PATH=$LD_LIB_DIR -export PATH=$LOCAL_VERILATOR_DIR/bin:$PATH -export VERILATOR_ROOT=$LOCAL_VERILATOR_DIR -export COURSIER_CACHE=$LOCAL_WORK_DIR/.coursier-cache make -j$LOCAL_MAKE_NPROC -C $LOCAL_SIM_DIR VERILATOR_OPT_FLAGS="-O0 -OG" 
JAVA_OPTS="-Xmx2500M -Xss8M" SBT_OPTS="-Dsbt.ivy.home=$LOCAL_CHIPYARD_DIR/.ivy2 -Dsbt.supershell=false -Dsbt.global.base=$LOCAL_CHIPYARD_DIR/.sbt -Dsbt.boot.directory=$LOCAL_CHIPYARD_DIR/.sbt/boot" CONFIG=GemminiRocketConfig + diff --git a/.github/scripts/enable-conda.sh b/.github/scripts/enable-conda.sh new file mode 100644 index 00000000..184ead9b --- /dev/null +++ b/.github/scripts/enable-conda.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +export PATH="$LOCAL_CONDA/bin:$PATH" +conda init +source ~/.bashrc +conda activate base +if ! { conda env list | grep 'chipyard'; } >/dev/null 2>&1; then + conda create -n chipyard + conda activate chipyard + conda install -c conda-forge conda-lock +fi +conda activate chipyard + diff --git a/.github/scripts/install-gemmini.sh b/.github/scripts/install-gemmini.sh new file mode 100755 index 00000000..0fa6460d --- /dev/null +++ b/.github/scripts/install-gemmini.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +#------------------------------------------------------------- +# installs gemmini +# +# run location: circle ci docker image +#------------------------------------------------------------- + +# turn echo on and error on earliest command +set -ex + +# get shared variables +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +source $SCRIPT_DIR/defaults.sh + +source $SCRIPT_DIR/enable-conda.sh + +cd $HOME +rm -rf chipyard +git clone --progress --verbose https://github.com/ucb-bar/chipyard.git chipyard +cd $LOCAL_CHIPYARD_DIR + +git fetch +git checkout $(cat $LOCAL_CHECKOUT_DIR/CHIPYARD.hash) + +./build-setup.sh esp-tools + +source env.sh + +cd toolchains/esp-tools/riscv-isa-sim/build +git checkout $(cat $LOCAL_CHECKOUT_DIR/SPIKE.hash) +make && make install + +cd $LOCAL_CHECKOUT_DIR +chown -R $(whoami) . 
+git config --global --add safe.directory $LOCAL_CHECKOUT_DIR +git submodule update --init --recursive software/gemmini-rocc-tests +rm -rf $LOCAL_CHIPYARD_DIR/generators/gemmini/* $LOCAL_CHIPYARD_DIR/generators/gemmini/.git* +mv -f $LOCAL_CHECKOUT_DIR/* $LOCAL_CHECKOUT_DIR/.git* $LOCAL_CHIPYARD_DIR/generators/gemmini/ + diff --git a/.github/scripts/install-verilator.sh b/.github/scripts/install-verilator.sh deleted file mode 100755 index b996b4d0..00000000 --- a/.github/scripts/install-verilator.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -# move verilator to the remote server - -# turn echo on and error on earliest command -set -ex - -# get shared variables -SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" -source $SCRIPT_DIR/defaults.sh - -if [ ! -d "$LOCAL_VERILATOR_DIR" ]; then - git clone http://git.veripool.org/git/verilator $LOCAL_VERILATOR_DIR - cd $LOCAL_VERILATOR_DIR - git checkout $VERILATOR_VERSION - autoconf - export VERILATOR_ROOT=$LOCAL_VERILATOR_DIR - ./configure - make -j$LOCAL_MAKE_NPROC -fi diff --git a/.github/scripts/prepare-for-rtl-build.sh b/.github/scripts/prepare-for-rtl-build.sh deleted file mode 100755 index df3ac470..00000000 --- a/.github/scripts/prepare-for-rtl-build.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -# turn echo on and error on earliest command -set -ex - -# get shared variables -SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" -source $SCRIPT_DIR/defaults.sh - -# check to see if both dirs exist -if [ ! 
-d "$LOCAL_CHIPYARD_DIR" ]; then - cd $HOME - - git clone --progress --verbose https://github.com/ucb-bar/chipyard.git chipyard - cd $LOCAL_CHIPYARD_DIR - - echo "Checking out Chipyard version: $(cat $LOCAL_CHECKOUT_DIR/CHIPYARD.hash)" - git fetch - git checkout $(cat $LOCAL_CHECKOUT_DIR/CHIPYARD.hash) - - # init all submodules (according to what chipyard wants) - ./scripts/init-submodules-no-riscv-tools.sh -fi diff --git a/.github/scripts/remove-chipyard.sh b/.github/scripts/remove-chipyard.sh new file mode 100755 index 00000000..8b82019e --- /dev/null +++ b/.github/scripts/remove-chipyard.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +set -ex + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +source $SCRIPT_DIR/defaults.sh + +rm -rf $LOCAL_CHIPYARD_DIR +rm -rf $LOCAL_CONDA + diff --git a/.github/scripts/run-tests-rtl.sh b/.github/scripts/run-tests-rtl.sh index c5907ddd..47a87ff1 100755 --- a/.github/scripts/run-tests-rtl.sh +++ b/.github/scripts/run-tests-rtl.sh @@ -5,9 +5,10 @@ set -ex SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" source $SCRIPT_DIR/defaults.sh +source $SCRIPT_DIR/enable-conda.sh -TOOLS_DIR=$LOCAL_ESP_DIR -PATH=$PATH:$LOCAL_ESP_DIR/bin +cd $LOCAL_CHIPYARD_DIR +source env.sh cd $LOCAL_CHIPYARD_DIR/generators/gemmini/software/gemmini-rocc-tests CFLAGS=-DFAST ./build.sh @@ -15,4 +16,3 @@ CFLAGS=-DFAST ./build.sh cd build make test-baremetal-bareMetalC RUNNER="'make -C $LOCAL_CHIPYARD_DIR/sims/verilator/ CONFIG=GemminiRocketConfig run-binary-hex BINARY='" - diff --git a/.github/scripts/run-tests-spike.sh b/.github/scripts/run-tests-spike.sh index 9f933aaf..93288a75 100755 --- a/.github/scripts/run-tests-spike.sh +++ b/.github/scripts/run-tests-spike.sh @@ -5,23 +5,14 @@ set -ex SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" source $SCRIPT_DIR/defaults.sh +source $SCRIPT_DIR/enable-conda.sh -# clone and build our version of spike -TOOLS_DIR=$LOCAL_ESP_DIR -PATH=$PATH:$LOCAL_ESP_DIR/bin - -git clone https://github.com/ucb-bar/esp-isa-sim.git -cd esp-isa-sim 
-git checkout $(cat $LOCAL_CHECKOUT_DIR/SPIKE.hash) -cp $LOCAL_CHIPYARD_DIR/generators/gemmini/software/gemmini-rocc-tests/include/gemmini_params.h gemmini/ - -mkdir build -cd build -../configure --prefix=$TOOLS_DIR -make -j8 install +cd $LOCAL_CHIPYARD_DIR +source env.sh cd $LOCAL_CHIPYARD_DIR/generators/gemmini/software/gemmini-rocc-tests ./build.sh cd build make test-baremetal + diff --git a/.github/workflows/config.yml b/.github/workflows/config.yml index cae5f7fd..f4a2dfda 100644 --- a/.github/workflows/config.yml +++ b/.github/workflows/config.yml @@ -1,11 +1,11 @@ name: Gemmini CI on: [push] jobs: - install-esp-toolchain: - name: install-esp-toolchain + install-gemmini: + name: gemmini-install runs-on: ubuntu-latest container: - image: ucbbar/chipyard-ci-image:554b436 + image: ucbbar/chipyard-ci-image:3f9150 options: --entrypoint /bin/bash steps: - name: checkout @@ -14,47 +14,23 @@ jobs: run: .github/scripts/defaults.sh id: get-paths - - name: toolchain-build - run: .github/scripts/build-toolchains.sh esp-tools + - name: install gemmini + run: .github/scripts/install-gemmini.sh - - name: cache esp-toolchain install - uses: actions/cache@v2 - with: - path: ${{ steps.get-paths.outputs.LOCAL_ESP_DIR }} - key: esp-tools-install-${{ github.ref }}-${{ github.sha }} - - prepare-build-environment: - name: prepare-build-environment - runs-on: ubuntu-latest - container: - image: ucbbar/chipyard-ci-image:554b436 - options: --entrypoint /bin/bash - steps: - - name: checkout - uses: actions/checkout@v2 - - name: get paths - run: .github/scripts/defaults.sh - id: get-paths - - - name: setup build environment - run: .github/scripts/prepare-for-rtl-build.sh - - name: install verilator - run: .github/scripts/install-verilator.sh - - - name: cache prepare-build-environment install + - name: cache gemmini install uses: actions/cache@v2 with: path: | ${{ steps.get-paths.outputs.LOCAL_CHIPYARD_DIR }} - ${{ steps.get-paths.outputs.LOCAL_VERILATOR_DIR }} - key: 
prepare-build-environment-${{ github.ref }}-${{ github.sha }} + ${{ steps.get-paths.outputs.LOCAL_CONDA }} + key: gemmini-install-${{ github.ref }}-${{ github.sha }} - prepare-gemmini-config: - name: prepare-gemmini-config - runs-on: ubuntu-latest - needs: [prepare-build-environment, install-esp-toolchain] + build-gemmini-config: + name: build-gemmini-config + runs-on: self-hosted + needs: install-gemmini container: - image: ucbbar/chipyard-ci-image:554b436 + image: ucbbar/chipyard-ci-image:3f9150 options: --entrypoint /bin/bash steps: - name: checkout @@ -63,38 +39,34 @@ jobs: run: .github/scripts/defaults.sh id: get-paths - - name: restore cache esp-toolchain install - uses: actions/cache@v2 - with: - path: ${{ steps.get-paths.outputs.LOCAL_ESP_DIR }} - key: esp-tools-install-${{ github.ref }}-${{ github.sha }} + - name: remove chipyard + run: .github/scripts/remove-chipyard.sh - - name: restore cache prepare-build-environment install + - name: restore cache gemmini install uses: actions/cache@v2 with: path: | ${{ steps.get-paths.outputs.LOCAL_CHIPYARD_DIR }} - ${{ steps.get-paths.outputs.LOCAL_VERILATOR_DIR }} - key: prepare-build-environment-${{ github.ref }}-${{ github.sha }} + ${{ steps.get-paths.outputs.LOCAL_CONDA }} + key: gemmini-install-${{ github.ref }}-${{ github.sha }} - name: Building Gemmini Config using Verilator run: .github/scripts/do-rtl-build.sh - - name: cache prepare-gemmini-config install + - name: cache build-gemmini-config install uses: actions/cache@v2 with: path: | ${{ steps.get-paths.outputs.LOCAL_CHIPYARD_DIR }} - ${{ steps.get-paths.outputs.LOCAL_VERILATOR_DIR }} - ${{ steps.get-paths.outputs.LOCAL_ESP_DIR }} - key: prepare-gemmini-config-${{ github.ref }}-${{ github.sha }} + ${{ steps.get-paths.outputs.LOCAL_CONDA }} + key: build-gemmini-config-${{ github.ref }}-${{ github.sha }} spike-run-tests: name: spike-run-tests runs-on: ubuntu-latest - needs: prepare-gemmini-config + needs: install-gemmini container: - image: 
ucbbar/chipyard-ci-image:554b436 + image: ucbbar/chipyard-ci-image:3f9150 options: --entrypoint /bin/bash steps: - name: checkout @@ -103,14 +75,16 @@ jobs: run: .github/scripts/defaults.sh id: get-paths - - name: restore cache prepare-gemmini-config install + - name: remove chipyard + run: .github/scripts/remove-chipyard.sh + + - name: restore cache gemmini install uses: actions/cache@v2 with: path: | ${{ steps.get-paths.outputs.LOCAL_CHIPYARD_DIR }} - ${{ steps.get-paths.outputs.LOCAL_VERILATOR_DIR }} - ${{ steps.get-paths.outputs.LOCAL_ESP_DIR }} - key: prepare-gemmini-config-${{ github.ref }}-${{ github.sha }} + ${{ steps.get-paths.outputs.LOCAL_CONDA }} + key: gemmini-install-${{ github.ref }}-${{ github.sha }} - name: run-tests run: .github/scripts/run-tests-spike.sh @@ -118,9 +92,9 @@ jobs: rtl-run-tests: name: rtl-run-tests runs-on: ubuntu-latest - needs: prepare-gemmini-config + needs: build-gemmini-config container: - image: ucbbar/chipyard-ci-image:554b436 + image: ucbbar/chipyard-ci-image:3f9150 options: --entrypoint /bin/bash steps: - name: checkout @@ -129,14 +103,16 @@ jobs: run: .github/scripts/defaults.sh id: get-paths - - name: restore cache prepare-gemmini-config install + - name: remove chipyard + run: .github/scripts/remove-chipyard.sh + + - name: restore cache build-gemmini-config install uses: actions/cache@v2 with: path: | ${{ steps.get-paths.outputs.LOCAL_CHIPYARD_DIR }} - ${{ steps.get-paths.outputs.LOCAL_VERILATOR_DIR }} - ${{ steps.get-paths.outputs.LOCAL_ESP_DIR }} - key: prepare-gemmini-config-${{ github.ref }}-${{ github.sha }} + ${{ steps.get-paths.outputs.LOCAL_CONDA }} + key: build-gemmini-config-${{ github.ref }}-${{ github.sha }} - name: run-tests run: .github/scripts/run-tests-rtl.sh diff --git a/CHIPYARD.hash b/CHIPYARD.hash index f41949c3..7fb91902 100644 --- a/CHIPYARD.hash +++ b/CHIPYARD.hash @@ -1 +1 @@ -117624d8eea27bafd613eec09e9b9b3e31239e08 +004297b6a8c01be1b2110c4cf4f9393ae1ff8805 diff --git a/README.md b/README.md 
index 0ffc6f2b..bbd1a6f9 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ We provide here a quick guide to installing Gemmini's dependencies (Chipyard and Dependencies --------- -Before beginning, install the [Chipyard dependencies](https://chipyard.readthedocs.io/en/latest/Chipyard-Basics/Initial-Repo-Setup.html#requirements) that are described here. +Before beginning, install the [Chipyard dependencies](https://chipyard.readthedocs.io/en/latest/Chipyard-Basics/Initial-Repo-Setup.html#default-requirements-installation). Installing Chipyard and Spike ----------------------------- @@ -38,20 +38,29 @@ Run these steps to install Chipyard and Spike (make sure to checkout the correct ```shell git clone https://github.com/ucb-bar/chipyard.git cd chipyard -git checkout 117624d8eea27bafd613eec09e9b9b3e31239e08 -./scripts/init-submodules-no-riscv-tools.sh -./scripts/build-toolchains.sh esp-tools +git checkout 1.8.1 +./build-setup.sh esp-tools source env.sh cd generators/gemmini -git fetch && git checkout v0.6.4 -git submodule update +git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" +git fetch && git checkout v0.7.0 +git submodule update --init --recursive + +SPIKE_HASH=$(cat SPIKE.hash) cd - cd toolchains/esp-tools/riscv-isa-sim/build -git fetch && git checkout 090e82c473fd28b4eb2011ffcd771ead6076faab +git fetch && git checkout $SPIKE_HASH make && make install + +# The final step is only necessary if you want to run MIDAS simulations with +# realistic DRAM models +cd - +cd sims/firesim +source sourceme-f1-manager.sh --skip-ssh-setup # Ignore error messages from this command +./build-setup.sh --library --skip-validate ``` Setting Up Gemmini @@ -141,7 +150,7 @@ cd chipyard/generators/gemmini Next steps -------- -Check out [our IISWC 2021 tutorial](https://sites.google.com/berkeley.edu/gemminitutorialiiswc2021/) to learn how to: +Check out our [MLSys 2022 tutorial](https://sites.google.com/berkeley.edu/gemmini-tutorial-mlsys-2022) (or our earlier 
but more out-of-date [IISWC 2021 tutorial](https://sites.google.com/berkeley.edu/gemminitutorialiiswc2021/)) to learn how to: * build different types of diverse accelerators using Gemmini. * add custom datatypes to Gemmini. * write your own Gemmini programs. @@ -464,7 +473,7 @@ When calling `config_mvin` (described below), the programmer can choose which `m **Format:** `config_ex rs1 rs2` - `rs1[1:0]` must be `00` - `rs1[2]` determines if output (0) or weight (1) stationary -- `rs1[4:3]` = activation function: either relu (1), relu6 (2), or no activation function (0) +- `rs1[3]` = activation function: either relu (1) or no activation function (0) - `rs1[8]` = should A be transposed? - `rs1[9]` = should B be transposed? - `rs1[31:16]` = the stride (in scratchpad addresses) by which the rows of A are fed into the systolic array. @@ -475,8 +484,6 @@ If the stride is 2, then we feed every other row into the systolic array instead - In the default config, `rs1[63:32]` is of type `float32` - `rs2[31:0]` = the number of bits by which the accumulated result of a matmul is right-shifted when leaving the systolic array - This parameter is only relevant in output-stationary mode, when partial sums must be accumulated within the systolic array itself, and scaled-down when leaving the systolic array and being written into the scratchpad. -- `rs2[63:32]` = the number of bits by which 6 should be left-shifted before applying relu6 - - This parameter is ignored if the relu6 activation function is not being used. - `funct` = 0 **Action:** mode <= rs1(2); shift <= rs2; A_stride <= rs1[31:16] @@ -530,6 +537,12 @@ The parameters controlling this feature are: **Action:** stride <= rs2; max-pooling parameters <= rs1 +### `config_norm` configures normalization commands +**Format:** `config_norm rs1 rs2` + +`config_norm` is an **experimental** command added primarily to support an integer-only variant of BERT called [I-BERT](https://arxiv.org/abs/2101.01321) on Gemmini. 
+The command allows users to set scalar constants that are used by I-BERT's GELU, layernorm, and softmax variants. + ### `flush` flushes the TLB **Format:** `flush rs1` - `rs1` = If `rs1[0]` is 1, then the current TLB request is skipped (if it has hit a page-fault and is waiting for an interrupt). diff --git a/SPIKE.hash b/SPIKE.hash index f08ac921..8cbb8d37 100644 --- a/SPIKE.hash +++ b/SPIKE.hash @@ -1 +1 @@ -090e82c473fd28b4eb2011ffcd771ead6076faab +051d820f08be84d069993de4375d29c91eb2f577 diff --git a/scripts/build-midas.sh b/scripts/build-midas.sh index c966513c..7e9da811 100755 --- a/scripts/build-midas.sh +++ b/scripts/build-midas.sh @@ -53,8 +53,10 @@ if [ dram_model == "" ]; then echo DRAM model must be provided. fi +export SYSLIBS=" $SYSLIBS -l:libdwarf.so -l:libelf.so -lz -lgmp " + cd ../../sims/firesim/ -source sourceme-f1-manager.sh &> build.log +source sourceme-f1-manager.sh --skip-ssh-setup &> build.log cd sim/ make ${simulator}${debug} TARGET_CONFIG=${dram_model}_WithDefaultFireSimBridges_WithFireSimConfigTweaks_chipyard.CustomGemminiSoCConfig diff --git a/scripts/build-onnx-inference.sh b/scripts/build-onnx-inference.sh index 23742f5c..07999b29 100755 --- a/scripts/build-onnx-inference.sh +++ b/scripts/build-onnx-inference.sh @@ -1,7 +1,8 @@ #!/bin/bash -cd /root/chipyard/generators/gemmini/software/onnxruntime-riscv/ +cd ./software/onnxruntime-riscv/ rm -rf ./build/ ./build.sh --parallel --enable_training --config=Debug --cmake_extra_defines onnxruntime_USE_SYSTOLIC=ON onnxruntime_SYSTOLIC_INT8=ON onnxruntime_SYSTOLIC_FP32=OFF cd ./systolic_runner/imagenet_runner/ ./build.sh --parallel --enable_training --config=Debug + diff --git a/scripts/build-onnx-training.sh b/scripts/build-onnx-training.sh index 55c9bc7b..bcb45565 100755 --- a/scripts/build-onnx-training.sh +++ b/scripts/build-onnx-training.sh @@ -1,6 +1,6 @@ #!/bin/bash -cd /root/chipyard/generators/gemmini/software/onnxruntime-riscv/ +cd ./software/onnxruntime-riscv/ rm -rf ./build/ 
./build.sh --parallel --enable_training --config=Debug --cmake_extra_defines onnxruntime_USE_SYSTOLIC=ON onnxruntime_SYSTOLIC_INT8=OFF onnxruntime_SYSTOLIC_FP32=ON cd ./systolic_runner/imagenet_trainer/ diff --git a/scripts/build-vcs.sh b/scripts/build-vcs.sh index e3213521..23f159b0 100755 --- a/scripts/build-vcs.sh +++ b/scripts/build-vcs.sh @@ -4,21 +4,24 @@ help () { echo "Build a cycle-accurate VCS simulator for RISCV Gemmini programs," echo 'matching `customConfig` in `configs/GemminiCustomConfigs.scala`.' echo - echo "Usage: $0 [-h|--help] [--debug]" + echo "Usage: $0 [-h|--help] [--debug] [-j [N]]" echo echo "Options:" echo " debug Builds a VCS simulator which generates waveforms. Without this" echo " option, the simulator will not generate any waveforms." + echo " j [N] Allow N jobs at once. Default is 1." exit } show_help=0 debug="" +j="1" while [ $# -gt 0 ] ; do case $1 in -h | --help) show_help=1 ;; - --debug) debug="debug" + --debug) debug="debug" ;; + -j) j=$2; shift esac shift @@ -29,5 +32,5 @@ if [ $show_help -eq 1 ]; then fi cd ../../sims/vcs/ -make ${debug} CONFIG=CustomGemminiSoCConfig +make -j$j ${debug} CONFIG=CustomGemminiSoCConfig diff --git a/scripts/build-verilator.sh b/scripts/build-verilator.sh index 965d335b..477c0910 100755 --- a/scripts/build-verilator.sh +++ b/scripts/build-verilator.sh @@ -4,21 +4,24 @@ help () { echo "Build a cycle-accurate Verilator simulator for RISCV Gemmini programs," echo 'matching `customConfig` in `configs/GemminiCustomConfigs.scala`.' echo - echo "Usage: $0 [-h|--help] [--debug]" + echo "Usage: $0 [-h|--help] [--debug] [-j [N]]" echo echo "Options:" echo " debug Builds a Verilator simulator which generates waveforms. Without" echo " this option, the simulator will not generate any waveforms." + echo " j [N] Allow N jobs at once. Default is 1." 
exit } show_help=0 debug="" +j="1" while [ $# -gt 0 ] ; do case $1 in -h | --help) show_help=1 ;; - --debug) debug="debug" + --debug) debug="debug" ;; + -j) j=$2; shift esac shift @@ -29,5 +32,5 @@ if [ $show_help -eq 1 ]; then fi cd ../../sims/verilator/ -make ${debug} CONFIG=CustomGemminiSoCConfig +make -j$j ${debug} CONFIG=CustomGemminiSoCConfig diff --git a/scripts/run-midas.sh b/scripts/run-midas.sh index 9bae1813..63616809 100755 --- a/scripts/run-midas.sh +++ b/scripts/run-midas.sh @@ -94,7 +94,7 @@ fi path="" suffix="" -for dir in bareMetalC mlps imagenet ; do +for dir in bareMetalC mlps imagenet transformers ; do if [ -f "software/gemmini-rocc-tests/build/${dir}/${binary}$default_suffix" ]; then path="${ROOT}/software/gemmini-rocc-tests/build/${dir}/" suffix=$default_suffix @@ -121,5 +121,6 @@ if [ ! -f ./${simulator}${DEBUG} ]; then fi ./${simulator}${DEBUG} ${PK} ${full_binary_path} ${waveform_flag} \ - +vcs+initreg+0 +vcs+initmem+0 +fesvr-step-size=128 +mm_relaxFunctionalModel_0=0 +mm_openPagePolicy_0=1 +mm_backendLatency_0=2 +mm_schedulerWindowSize_0=8 +mm_transactionQueueDepth_0=8 +mm_dramTimings_tAL_0=0 +mm_dramTimings_tCAS_0=14 +mm_dramTimings_tCMD_0=1 +mm_dramTimings_tCWD_0=10 +mm_dramTimings_tCCD_0=4 +mm_dramTimings_tFAW_0=25 +mm_dramTimings_tRAS_0=33 +mm_dramTimings_tREFI_0=7800 +mm_dramTimings_tRC_0=47 +mm_dramTimings_tRCD_0=14 +mm_dramTimings_tRFC_0=160 +mm_dramTimings_tRRD_0=8 +mm_dramTimings_tRP_0=14 +mm_dramTimings_tRTP_0=8 +mm_dramTimings_tRTRS_0=2 +mm_dramTimings_tWR_0=15 +mm_dramTimings_tWTR_0=8 +mm_rowAddr_offset_0=18 +mm_rowAddr_mask_0=65535 +mm_rankAddr_offset_0=16 +mm_rankAddr_mask_0=3 +mm_bankAddr_offset_0=13 +mm_bankAddr_mask_0=7 +mm_llc_wayBits_0=3 +mm_llc_setBits_0=12 +mm_llc_blockBits_0=7 +mm_llc_activeMSHRs_0=8 +shmemportname0=0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 +macaddr0=00:00:00:00:00:02 +niclog0=niclog0 +linklatency0=6405 +netbw0=100 +netburst0=8 +nic-loopback0 
+tracefile=TRACEFILE +blkdev-in-mem0=128 +blkdev-log0=blkdev-log0 +autocounter-readrate=1000 +autocounter-filename=AUTOCOUNTERFILE +dramsim +max-cycles=100000000 \ + +vcs+initreg+0 +vcs+initmem+0 +fesvr-step-size=128 +mm_relaxFunctionalModel_0=0 +mm_openPagePolicy_0=1 +mm_backendLatency_0=2 +mm_schedulerWindowSize_0=8 +mm_transactionQueueDepth_0=8 +mm_dramTimings_tAL_0=0 +mm_dramTimings_tCAS_0=14 +mm_dramTimings_tCMD_0=1 +mm_dramTimings_tCWD_0=10 +mm_dramTimings_tCCD_0=4 +mm_dramTimings_tFAW_0=25 +mm_dramTimings_tRAS_0=33 +mm_dramTimings_tREFI_0=7800 +mm_dramTimings_tRC_0=47 +mm_dramTimings_tRCD_0=14 +mm_dramTimings_tRFC_0=160 +mm_dramTimings_tRRD_0=8 +mm_dramTimings_tRP_0=14 +mm_dramTimings_tRTP_0=8 +mm_dramTimings_tRTRS_0=2 +mm_dramTimings_tWR_0=15 +mm_dramTimings_tWTR_0=8 +mm_rowAddr_offset_0=18 +mm_rowAddr_mask_0=65535 +mm_rankAddr_offset_0=16 +mm_rankAddr_mask_0=3 +mm_bankAddr_offset_0=13 +mm_bankAddr_mask_0=7 +mm_llc_wayBits_0=3 +mm_llc_setBits_0=12 +mm_llc_blockBits_0=7 +mm_llc_activeMSHRs_0=8 +shmemportname0=0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 +macaddr0=00:00:00:00:00:02 +niclog0=niclog0 +linklatency0=6405 +netbw0=100 +netburst0=8 +nic-loopback0 +tracefile=TRACEFILE +blkdev-in-mem0=128 +blkdev-log0=blkdev-log0 +autocounter-readrate=1000 +autocounter-filename=AUTOCOUNTERFILE +max-cycles=100000000 \ + +dramsim +dramsim_ini_dir=/home/eecs/hngenc/chip/generators/testchipip/src/main/resources/dramsim2_ini \ 2>/dev/null diff --git a/scripts/run-spike.sh b/scripts/run-spike.sh index 00b5349f..1638b76c 100755 --- a/scripts/run-spike.sh +++ b/scripts/run-spike.sh @@ -60,7 +60,7 @@ fi path="" suffix="" -for dir in bareMetalC mlps imagenet ; do +for dir in bareMetalC mlps imagenet transformers ; do if [ -f "software/gemmini-rocc-tests/build/${dir}/${binary}$default_suffix" ]; then path="software/gemmini-rocc-tests/build/${dir}/" suffix=$default_suffix diff --git a/scripts/run-vcs.sh b/scripts/run-vcs.sh 
index 0fcbd9b1..40ce9bda 100755 --- a/scripts/run-vcs.sh +++ b/scripts/run-vcs.sh @@ -73,7 +73,7 @@ fi path="" suffix="" -for dir in bareMetalC mlps imagenet ; do +for dir in bareMetalC mlps imagenet transformers ; do if [ -f "software/gemmini-rocc-tests/build/${dir}/${binary}$default_suffix" ]; then path="${ROOT}/software/gemmini-rocc-tests/build/${dir}/" suffix=$default_suffix diff --git a/scripts/run-verilator.sh b/scripts/run-verilator.sh index 58d40d2b..b4f21458 100755 --- a/scripts/run-verilator.sh +++ b/scripts/run-verilator.sh @@ -73,7 +73,7 @@ fi path="" suffix="" -for dir in bareMetalC mlps imagenet ; do +for dir in bareMetalC mlps imagenet transformers ; do if [ -f "software/gemmini-rocc-tests/build/${dir}/${binary}$default_suffix" ]; then path="${ROOT}/software/gemmini-rocc-tests/build/${dir}/" suffix=$default_suffix diff --git a/software/gemmini-ort.json b/software/gemmini-ort.json index 7f561d79..c4a95253 100644 --- a/software/gemmini-ort.json +++ b/software/gemmini-ort.json @@ -52,7 +52,7 @@ "/output/mobilenet_optimized_ws_nhwc_out.txt" ], "overlay": "../onnxruntime-riscv/systolic_runner/imagenet_runner", - "rootfs-size": "1GiB", + "rootfs-size": "16GiB", "run": "run-ort.sh" } diff --git a/software/gemmini-rocc-tests b/software/gemmini-rocc-tests index e326e7c4..ae0cd823 160000 --- a/software/gemmini-rocc-tests +++ b/software/gemmini-rocc-tests @@ -1 +1 @@ -Subproject commit e326e7c43457ff08669fe88edcaa395d846474d8 +Subproject commit ae0cd8236d32fccf7197a7ac0634df5513cec4db diff --git a/software/gemmini-tests-interactive.json b/software/gemmini-tests-interactive.json index 8bd5f7ea..0fe52409 100644 --- a/software/gemmini-tests-interactive.json +++ b/software/gemmini-tests-interactive.json @@ -3,5 +3,6 @@ "workdir" : ".", "base" : "br-base.json", "overlay" : "overlay", - "host-init" : "host-init.sh" + "host-init" : "host-init.sh", + "rootfs-size" : "16GiB" } diff --git a/software/gemmini-tests.json b/software/gemmini-tests.json index 
72f8661c..fc0e45a9 100644 --- a/software/gemmini-tests.json +++ b/software/gemmini-tests.json @@ -4,5 +4,6 @@ "base" : "br-base.json", "overlay" : "overlay", "host-init" : "host-init.sh", - "command": "/root/run-tests.sh" + "command": "/root/run-tests.sh", + "rootfs-size" : "16GiB" } diff --git a/software/onnxruntime-riscv b/software/onnxruntime-riscv index 7bbd0496..f6d2fc95 160000 --- a/software/onnxruntime-riscv +++ b/software/onnxruntime-riscv @@ -1 +1 @@ -Subproject commit 7bbd0496b579863c6906c0449932ac5ddc4c5357 +Subproject commit f6d2fc95463316ec47d7f832f35be03c26887922 diff --git a/src/main/scala/gemmini/AccumulatorMem.scala b/src/main/scala/gemmini/AccumulatorMem.scala index f8b62298..dd5ed821 100644 --- a/src/main/scala/gemmini/AccumulatorMem.scala +++ b/src/main/scala/gemmini/AccumulatorMem.scala @@ -5,30 +5,35 @@ import chisel3.util._ import Util._ -class AccumulatorReadReq[T <: Data](n: Int, shift_width: Int, scale_t: T) extends Bundle { +class AccumulatorReadReq[T <: Data: Arithmetic, U <: Data](n: Int, acc_t: T, scale_t: U) extends Bundle { val addr = UInt(log2Ceil(n).W) val scale = scale_t - val relu6_shift = UInt(shift_width.W) - val act = UInt(2.W) // TODO magic number + val igelu_qb = acc_t.cloneType + val igelu_qc = acc_t.cloneType + val iexp_qln2 = acc_t.cloneType + val iexp_qln2_inv = acc_t.cloneType + val act = UInt(Activation.bitwidth.W) // TODO magic number val full = Bool() // Whether or not we return the full bitwidth output val fromDMA = Bool() } -class AccumulatorReadResp[T <: Data: Arithmetic, U <: Data](fullDataType: Vec[Vec[T]], scale_t: U, shift_width: Int) extends Bundle { +class AccumulatorReadResp[T <: Data: Arithmetic, U <: Data](fullDataType: Vec[Vec[T]], scale_t: U) extends Bundle { val data = fullDataType.cloneType val fromDMA = Bool() val scale = scale_t.cloneType - val relu6_shift = UInt(shift_width.W) - val act = UInt(2.W) // TODO magic number - val acc_bank_id = UInt(2.W) // TODO don't hardcode + val igelu_qb = 
fullDataType.head.head.cloneType + val igelu_qc = fullDataType.head.head.cloneType + val iexp_qln2 = fullDataType.head.head.cloneType + val iexp_qln2_inv = fullDataType.head.head.cloneType + val act = UInt(Activation.bitwidth.W) // TODO magic number + val acc_bank_id = UInt(2.W) // TODO magic number } -class AccumulatorReadIO[T <: Data: Arithmetic, U <: Data](n: Int, shift_width: Int, fullDataType: Vec[Vec[T]], scale_t: U) extends Bundle { - val req = Decoupled(new AccumulatorReadReq[U](n, shift_width, scale_t)) - val resp = Flipped(Decoupled(new AccumulatorReadResp[T, U](fullDataType, scale_t, shift_width))) - +class AccumulatorReadIO[T <: Data: Arithmetic, U <: Data](n: Int, fullDataType: Vec[Vec[T]], scale_t: U) extends Bundle { + val req = Decoupled(new AccumulatorReadReq[T, U](n, fullDataType.head.head.cloneType, scale_t)) + val resp = Flipped(Decoupled(new AccumulatorReadResp[T, U](fullDataType, scale_t))) } class AccumulatorWriteReq[T <: Data: Arithmetic](n: Int, t: Vec[Vec[T]]) extends Bundle { @@ -36,15 +41,13 @@ class AccumulatorWriteReq[T <: Data: Arithmetic](n: Int, t: Vec[Vec[T]]) extends val data = t.cloneType val acc = Bool() val mask = Vec(t.getWidth / 8, Bool()) // TODO Use aligned_to here - // val current_waddr = Flipped(Valid(UInt(log2Ceil(n).W))) // This is the raddr that is being fed into the SRAM right now - } class AccumulatorMemIO [T <: Data: Arithmetic, U <: Data](n: Int, t: Vec[Vec[T]], scale_t: U, acc_sub_banks: Int, use_shared_ext_mem: Boolean ) extends Bundle { - val read = Flipped(new AccumulatorReadIO(n, log2Ceil(t.head.head.getWidth), t, scale_t)) + val read = Flipped(new AccumulatorReadIO(n, t, scale_t)) val write = Flipped(Decoupled(new AccumulatorWriteReq(n, t))) val ext_mem = if (use_shared_ext_mem) Some(Vec(acc_sub_banks, new ExtMemIO)) else None @@ -55,7 +58,6 @@ class AccumulatorMemIO [T <: Data: Arithmetic, U <: Data](n: Int, t: Vec[Vec[T]] val op2 = Output(t.cloneType) val sum = Input(t.cloneType) } - } class AccPipe[T <: 
Data : Arithmetic](latency: Int, t: T)(implicit ev: Arithmetic[T]) extends Module { @@ -98,8 +100,6 @@ class AccumulatorMem[T <: Data, U <: Data]( // to it, then we might not get the written data. We might need some kind of cooldown counter after addresses in the // accumulator have been written to for configurations with such small matrices - // TODO Refuse a read from an address which has only just been written to - // TODO make a new aligned_to variable specifically for AccumulatorMem. We should assume that inputs are at least // accType.getWidth/8 aligned, because it won't make sense to do matrix additions directly in the DMA otherwise. @@ -291,7 +291,7 @@ class AccumulatorMem[T <: Data, U <: Data]( } } - val q = Module(new Queue(new AccumulatorReadResp(t, scale_t, log2Ceil(t.head.head.getWidth)), 1, true, true)) + val q = Module(new Queue(new AccumulatorReadResp(t, scale_t), 1, true, true)) q.io.enq.bits.data := rdata_for_read_resp if (is_dummy) { @@ -300,7 +300,10 @@ class AccumulatorMem[T <: Data, U <: Data]( } q.io.enq.bits.scale := RegNext(io.read.req.bits.scale) - q.io.enq.bits.relu6_shift := RegNext(io.read.req.bits.relu6_shift) + q.io.enq.bits.igelu_qb := RegNext(io.read.req.bits.igelu_qb) + q.io.enq.bits.igelu_qc := RegNext(io.read.req.bits.igelu_qc) + q.io.enq.bits.iexp_qln2 := RegNext(io.read.req.bits.iexp_qln2) + q.io.enq.bits.iexp_qln2_inv := RegNext(io.read.req.bits.iexp_qln2_inv) q.io.enq.bits.act := RegNext(io.read.req.bits.act) q.io.enq.bits.fromDMA := RegNext(io.read.req.bits.fromDMA) q.io.enq.bits.acc_bank_id := DontCare @@ -310,7 +313,10 @@ class AccumulatorMem[T <: Data, U <: Data]( io.read.resp.bits.data := p.bits.data io.read.resp.bits.fromDMA := p.bits.fromDMA - io.read.resp.bits.relu6_shift := p.bits.relu6_shift + io.read.resp.bits.igelu_qb := p.bits.igelu_qb + io.read.resp.bits.igelu_qc := p.bits.igelu_qc + io.read.resp.bits.iexp_qln2 := p.bits.iexp_qln2 + io.read.resp.bits.iexp_qln2_inv := p.bits.iexp_qln2_inv io.read.resp.bits.act := 
p.bits.act io.read.resp.bits.scale := p.bits.scale io.read.resp.bits.acc_bank_id := DontCare // This is set in Scratchpad diff --git a/src/main/scala/gemmini/AccumulatorScale.scala b/src/main/scala/gemmini/AccumulatorScale.scala index 2d23af1d..1fdd15fa 100644 --- a/src/main/scala/gemmini/AccumulatorScale.scala +++ b/src/main/scala/gemmini/AccumulatorScale.scala @@ -1,16 +1,16 @@ + package gemmini import chisel3._ import chisel3.util._ - import Util._ -class AccumulatorReadRespWithFullData[T <: Data: Arithmetic, U <: Data](fullDataType: Vec[Vec[T]], scale_t: U, shift_width: Int) extends Bundle { - val resp = new AccumulatorReadResp(fullDataType, scale_t, shift_width) +class AccumulatorReadRespWithFullData[T <: Data: Arithmetic, U <: Data](fullDataType: Vec[Vec[T]], scale_t: U) + extends Bundle { + val resp = new AccumulatorReadResp(fullDataType, scale_t) val full_data = fullDataType.cloneType } - class AccumulatorScaleResp[T <: Data: Arithmetic](fullDataType: Vec[Vec[T]], rDataType: Vec[Vec[T]]) extends Bundle { val full_data = fullDataType.cloneType val data = rDataType.cloneType @@ -19,26 +19,33 @@ class AccumulatorScaleResp[T <: Data: Arithmetic](fullDataType: Vec[Vec[T]], rDa } class AccumulatorScaleIO[T <: Data: Arithmetic, U <: Data]( - fullDataType: Vec[Vec[T]], scale_t: U, shift_width: Int, + fullDataType: Vec[Vec[T]], scale_t: U, rDataType: Vec[Vec[T]] ) extends Bundle { - val in = Flipped(Decoupled(new AccumulatorReadResp[T,U](fullDataType, scale_t, shift_width))) + val in = Flipped(Decoupled(new NormalizedOutput[T,U](fullDataType, scale_t))) val out = Decoupled(new AccumulatorScaleResp[T](fullDataType, rDataType)) } class AccScaleDataWithIndex[T <: Data: Arithmetic, U <: Data](t: T, u: U) extends Bundle { - val shift_width = log2Ceil(t.getWidth) - val scale = u.cloneType val act = UInt(2.W) // TODO magic number - val relu6_shift = UInt(shift_width.W) + val igelu_qb = t.cloneType + val igelu_qc = t.cloneType + val iexp_qln2 = t.cloneType + val 
iexp_qln2_inv = t.cloneType + val mean = t.cloneType + val max = t.cloneType + val inv_stddev = u.cloneType + val inv_sum_exp = u.cloneType val data = t.cloneType val full_data = t.cloneType val id = UInt(2.W) // TODO hardcoded val index = UInt() } -class AccScalePipe[T <: Data : Arithmetic, U <: Data](t: T, rDataType: Vec[Vec[T]], scale_func: (T, U) => T, scale_t: U, latency: Int, has_nonlinear_activations: Boolean)(implicit ev: Arithmetic[T]) extends Module { +class AccScalePipe[T <: Data, U <: Data](t: T, rDataType: Vec[Vec[T]], scale_func: (T, U) => T, scale_t: U, + latency: Int, has_nonlinear_activations: Boolean, has_normalizations: Boolean) + (implicit ev: Arithmetic[T]) extends Module { val u = scale_t val io = IO(new Bundle { val in = Input(Valid(new AccScaleDataWithIndex(t, u)(ev))) @@ -47,68 +54,97 @@ class AccScalePipe[T <: Data : Arithmetic, U <: Data](t: T, rDataType: Vec[Vec[T import ev._ val out = WireInit(io.in) - val e_scaled = scale_func(io.in.bits.data, io.in.bits.scale) + val e = io.in.bits.data + + val e_act = MuxCase(e, Seq( + (has_nonlinear_activations.B && io.in.bits.act === Activation.RELU) -> e.relu, + (has_nonlinear_activations.B && has_normalizations.B && io.in.bits.act === Activation.LAYERNORM) -> + (e - io.in.bits.mean).mult_with_reciprocal(io.in.bits.inv_stddev), + (has_nonlinear_activations.B && has_normalizations.B && io.in.bits.act === Activation.IGELU) -> + AccumulatorScale.igelu(e, io.in.bits.igelu_qb, io.in.bits.igelu_qc), + (has_nonlinear_activations.B && has_normalizations.B && io.in.bits.act === Activation.SOFTMAX) -> + scale_func( + AccumulatorScale.iexp(e - io.in.bits.max, io.in.bits.iexp_qln2, io.in.bits.iexp_qln2_inv, io.in.bits.igelu_qb, io.in.bits.igelu_qc), + io.in.bits.inv_sum_exp.asTypeOf(scale_t)), + )) + + val e_scaled = scale_func(e_act, io.in.bits.scale) val e_clipped = e_scaled.clippedToWidthOf(rDataType.head.head) - val e_act = MuxCase(e_clipped, Seq( - (has_nonlinear_activations.B && io.in.bits.act === 
Activation.RELU) -> e_clipped.relu, - (has_nonlinear_activations.B && io.in.bits.act === Activation.RELU6) -> e_clipped.relu6(io.in.bits.relu6_shift))) - out.bits.data := e_act + out.bits.data := e_clipped io.out := Pipe(out, latency) } -class AccumulatorScale[T <: Data: Arithmetic, U <: Data]( +class AccumulatorScale[T <: Data, U <: Data]( fullDataType: Vec[Vec[T]], rDataType: Vec[Vec[T]], - scale_t: U, shift_width: Int, + scale_t: U, read_small_data: Boolean, read_full_data: Boolean, scale_func: (T, U) => T, num_scale_units: Int, latency: Int, - has_nonlinear_activations: Boolean)(implicit ev: Arithmetic[T]) extends Module { + has_nonlinear_activations: Boolean, has_normalizations: Boolean)(implicit ev: Arithmetic[T]) extends Module { import ev._ val io = IO(new AccumulatorScaleIO[T,U]( - fullDataType, scale_t, shift_width, rDataType + fullDataType, scale_t, rDataType )(ev)) - val t = io.in.bits.data(0)(0).cloneType + val t = io.in.bits.acc_read_resp.data(0)(0).cloneType + val acc_read_data = io.in.bits.acc_read_resp.data val out = Wire(Decoupled(new AccumulatorScaleResp[T]( fullDataType, rDataType)(ev))) if (num_scale_units == -1) { - val in = Wire(Decoupled(new AccumulatorReadRespWithFullData(fullDataType, scale_t, shift_width)(ev))) + val data = io.in.bits.acc_read_resp.data + val act = io.in.bits.acc_read_resp.act + val igelu_qb = io.in.bits.acc_read_resp.igelu_qb + val igelu_qc = io.in.bits.acc_read_resp.igelu_qc + val iexp_qln2 = io.in.bits.acc_read_resp.iexp_qln2 + val iexp_qln2_inv = io.in.bits.acc_read_resp.iexp_qln2_inv + val scale = io.in.bits.acc_read_resp.scale + + val activated_data = VecInit(data.map(v => VecInit(v.map { e => + val e_act = MuxCase(e, Seq( + (has_nonlinear_activations.B && act === Activation.RELU) -> e.relu, + (has_nonlinear_activations.B && has_normalizations.B && act === Activation.LAYERNORM) -> + (e - io.in.bits.mean).mult_with_reciprocal(io.in.bits.inv_stddev), + (has_nonlinear_activations.B && has_normalizations.B && act === 
Activation.IGELU) -> + AccumulatorScale.igelu(e, igelu_qb, igelu_qc), + (has_nonlinear_activations.B && has_normalizations.B && act === Activation.SOFTMAX) -> + scale_func( + AccumulatorScale.iexp(e - io.in.bits.max, iexp_qln2, iexp_qln2_inv, igelu_qb, igelu_qc), + io.in.bits.inv_sum_exp.asTypeOf(scale_t)), + )) + + val e_scaled = scale_func(e_act, scale) + val e_clipped = e_scaled.clippedToWidthOf(rDataType.head.head) + + e_clipped + }))) + + val in = Wire(Decoupled(new AccumulatorReadRespWithFullData(fullDataType, scale_t)(ev))) in.valid := io.in.valid io.in.ready := in.ready - in.bits.resp := io.in.bits - in.bits.full_data := io.in.bits.data - - val pipe_out = Pipeline(in, latency, Seq.fill(latency)((x: AccumulatorReadRespWithFullData[T,U]) => x) :+ { - x: AccumulatorReadRespWithFullData[T,U] => - val activated_rdata = VecInit(x.resp.data.map(v => VecInit(v.map { e => - val e_scaled = scale_func(e, x.resp.scale) - val e_clipped = e_scaled.clippedToWidthOf(rDataType.head.head) - val e_act = MuxCase(e_clipped, Seq( - (x.resp.act === Activation.RELU) -> e_clipped.relu, - (x.resp.act === Activation.RELU6) -> e_clipped.relu6(x.resp.relu6_shift))) - - e_act - }))) - val result = WireInit(x) - result.resp.data := activated_rdata - result - }) - out.valid := pipe_out.valid + in.bits.resp := io.in.bits.acc_read_resp + in.bits.full_data := acc_read_data + in.bits.resp.data := activated_data + + val pipe_out = Pipeline(in, latency) + + out.valid := pipe_out.valid pipe_out.ready := out.ready out.bits.full_data := pipe_out.bits.full_data out.bits.data := pipe_out.bits.resp.data out.bits.fromDMA := pipe_out.bits.resp.fromDMA out.bits.acc_bank_id := pipe_out.bits.resp.acc_bank_id } else { - val width = io.in.bits.data.size * io.in.bits.data(0).size + val width = acc_read_data.size * acc_read_data(0).size val nEntries = 3 - val regs = Reg(Vec(nEntries, Valid(new AccumulatorReadResp[T,U]( - fullDataType, scale_t, shift_width)(ev)))) + /*val regs = Reg(Vec(nEntries, Valid(new 
AccumulatorReadResp[T,U]( + fullDataType, scale_t)(ev))))*/ + val regs = Reg(Vec(nEntries, Valid(new NormalizedOutput[T,U]( + fullDataType, scale_t)(ev)))) val out_regs = Reg(Vec(nEntries, new AccumulatorScaleResp[T]( fullDataType, rDataType)(ev))) @@ -124,7 +160,7 @@ class AccumulatorScale[T <: Data: Arithmetic, U <: Data]( regs(i).valid := false.B } } - head_oh := (head_oh << 1) | head_oh(nEntries-1) + head_oh := (head_oh << 1).asUInt() | head_oh(nEntries-1) } io.in.ready := !Mux1H(tail_oh.asBools, regs.map(_.valid)) || (tail_oh === head_oh && out.fire) @@ -133,13 +169,13 @@ class AccumulatorScale[T <: Data: Arithmetic, U <: Data]( when (tail_oh(i)) { regs(i).valid := true.B regs(i).bits := io.in.bits - out_regs(i).fromDMA := io.in.bits.fromDMA - out_regs(i).acc_bank_id := io.in.bits.acc_bank_id + out_regs(i).fromDMA := io.in.bits.acc_read_resp.fromDMA + out_regs(i).acc_bank_id := io.in.bits.acc_read_resp.acc_bank_id fired_masks(i).foreach(_ := false.B) completed_masks(i).foreach(_ := false.B) } } - tail_oh := (tail_oh << 1) | tail_oh(nEntries-1) + tail_oh := (tail_oh << 1).asUInt() | tail_oh(nEntries-1) } val inputs = Seq.fill(width*nEntries) { Wire(Decoupled(new AccScaleDataWithIndex(t, scale_t)(ev))) } @@ -147,12 +183,22 @@ class AccumulatorScale[T <: Data: Arithmetic, U <: Data]( for (i <- 0 until nEntries) { for (w <- 0 until width) { val input = inputs(i*width+w) + + val acc_read_resp = regs(i).bits.acc_read_resp + input.valid := regs(i).valid && !fired_masks(i)(w) - input.bits.data := regs(i).bits.data(w / io.in.bits.data(0).size)(w % io.in.bits.data(0).size) - input.bits.full_data := regs(i).bits.data(w / io.in.bits.data(0).size)(w % io.in.bits.data(0).size) - input.bits.scale := regs(i).bits.scale - input.bits.act := regs(i).bits.act - input.bits.relu6_shift := regs(i).bits.relu6_shift + input.bits.data := acc_read_resp.data(w / acc_read_data(0).size)(w % acc_read_data(0).size) + input.bits.full_data := acc_read_resp.data(w / acc_read_data(0).size)(w % 
acc_read_data(0).size) + input.bits.scale := acc_read_resp.scale + input.bits.act := acc_read_resp.act + input.bits.igelu_qb := acc_read_resp.igelu_qb + input.bits.igelu_qc := acc_read_resp.igelu_qc + input.bits.iexp_qln2 := acc_read_resp.iexp_qln2 + input.bits.iexp_qln2_inv := acc_read_resp.iexp_qln2_inv + input.bits.mean := regs(i).bits.mean + input.bits.max := regs(i).bits.max + input.bits.inv_stddev := regs(i).bits.inv_stddev + input.bits.inv_sum_exp := regs(i).bits.inv_sum_exp input.bits.id := i.U input.bits.index := w.U when (input.fire) { @@ -171,15 +217,16 @@ class AccumulatorScale[T <: Data: Arithmetic, U <: Data]( when (reset.asBool) { arbOut.valid := false.B } - val pipe = Module(new AccScalePipe(t, rDataType, scale_func, scale_t, latency, has_nonlinear_activations)(ev, ev)) + val pipe = Module(new AccScalePipe(t, rDataType, scale_func, scale_t, latency, has_nonlinear_activations, + has_normalizations)) pipe.io.in := arbOut val pipe_out = pipe.io.out for (j <- 0 until nEntries) { for (w <- 0 until width) { if ((j*width+w) % num_scale_units == i) { - val id0 = w % io.in.bits.data(0).size - val id1 = w / io.in.bits.data(0).size + val id0 = w % acc_read_data(0).size + val id1 = w / acc_read_data(0).size when (pipe_out.fire && pipe_out.bits.id === j.U && pipe_out.bits.index === w.U) { out_regs(j).data (id1)(id0) := pipe_out.bits.data out_regs(j).full_data(id1)(id0) := pipe_out.bits.full_data @@ -205,6 +252,40 @@ class AccumulatorScale[T <: Data: Arithmetic, U <: Data]( io.out.bits.full_data := out.bits.full_data else io.out.bits.full_data := DontCare - } +object AccumulatorScale { + def igelu[T <: Data](q: T, qb: T, qc: T)(implicit ev: Arithmetic[T]): T = { + import ev._ + + val zero = q.zero + val one = q.identity + def neg(x: T) = zero-x + + val q_sign = Mux(q.zero > q, neg(one), one) + val q_abs = Mux(q.zero > q, neg(q), q) + val q_clipped = Mux(q_abs > neg(qb), neg(qb), q_abs) + val q_poly = qc.mac(q_clipped + qb, q_clipped + qb).withWidthOf(q) + val 
q_erf = (q_sign * q_poly).withWidthOf(q) + (q * (q_erf + qc)).withWidthOf(q) + } + + def iexp[T <: Data](q: T, qln2: T, qln2_inv: T, qb: T, qc: T)(implicit ev: Arithmetic[T]): T = { + import ev._ + + val zero = q.zero + def neg(x: T) = zero-x + + // qln2_inv needs scale to be + // 1 / (2 ** 16) / S + + // qln2_inv / S / (2 ** 16) = 1 / ln2 + // q * qln2_inv = x / S / ln2 * S * (2 ** 16) = x / ln2 * (2 ** 16) + val neg_q_iexp = neg(q) + val z_iexp = (neg_q_iexp * qln2_inv).asUInt().do_>>(16).asTypeOf(q) // q is non-positive + val qp_iexp = q.mac(z_iexp, qln2).withWidthOf(q) + val q_poly_iexp = qc.mac(qp_iexp + qb, qp_iexp + qb).withWidthOf(q) + // we dont want a rounding shift + (q_poly_iexp.asUInt().do_>>(z_iexp.asUInt()(5, 0))).asTypeOf(q) + }} + diff --git a/src/main/scala/gemmini/Activation.scala b/src/main/scala/gemmini/Activation.scala index ed7df57f..1b7d94e6 100644 --- a/src/main/scala/gemmini/Activation.scala +++ b/src/main/scala/gemmini/Activation.scala @@ -5,5 +5,9 @@ import chisel3._ object Activation { val NONE = 0.U val RELU = 1.U - val RELU6 = 2.U + val LAYERNORM = 2.U + val IGELU = 3.U + val SOFTMAX = 4.U + + val bitwidth = 3 } diff --git a/src/main/scala/gemmini/Arithmetic.scala b/src/main/scala/gemmini/Arithmetic.scala index 4f8e9343..cdd36396 100644 --- a/src/main/scala/gemmini/Arithmetic.scala +++ b/src/main/scala/gemmini/Arithmetic.scala @@ -32,14 +32,21 @@ abstract class ArithmeticOps[T <: Data](self: T) { def *(t: T): T def mac(m1: T, m2: T): T // Returns (m1 * m2 + self) def +(t: T): T + def -(t: T): T def >>(u: UInt): T // This is a rounding shift! 
Rounds away from 0 def >(t: T): Bool def identity: T def withWidthOf(t: T): T def clippedToWidthOf(t: T): T // Like "withWidthOf", except that it saturates def relu: T - def relu6(shift: UInt): T def zero: T + def minimum: T + + // Optional parameters, which only need to be defined if you want to enable various optimizations for transformers + def divider(denom_t: UInt): Option[(DecoupledIO[UInt], DecoupledIO[T])] = None + def sqrt: Option[(DecoupledIO[UInt], DecoupledIO[T])] = None + def reciprocal[U <: Data](u: U): Option[(DecoupledIO[UInt], DecoupledIO[U])] = None + def mult_with_reciprocal[U <: Data](reciprocal: U) = self } object Arithmetic { @@ -48,6 +55,7 @@ object Arithmetic { override def *(t: UInt) = self * t override def mac(m1: UInt, m2: UInt) = m1 * m2 + self override def +(t: UInt) = self + t + override def -(t: UInt) = self - t override def >>(u: UInt) = { // The equation we use can be found here: https://riscv.github.io/documents/riscv-v-spec/#_vector_fixed_point_rounding_mode_register_vxrm @@ -72,15 +80,10 @@ object Arithmetic { } override def relu: UInt = self - override def relu6(shift: UInt): UInt = { - val max6 = (6.U << shift).asUInt() - val maxwidth = ((1 << (self.getWidth-1))-1).U - val max = Mux(max6 > maxwidth, maxwidth, max6)(self.getWidth-1, 0).asUInt() - Mux(self < max, self, max) - } override def zero: UInt = 0.U override def identity: UInt = 1.U + override def minimum: UInt = 0.U } } @@ -89,6 +92,7 @@ object Arithmetic { override def *(t: SInt) = self * t override def mac(m1: SInt, m2: SInt) = m1 * m2 + self override def +(t: SInt) = self + t + override def -(t: SInt) = self - t override def >>(u: UInt) = { // The equation we use can be found here: https://riscv.github.io/documents/riscv-v-spec/#_vector_fixed_point_rounding_mode_register_vxrm @@ -122,15 +126,204 @@ object Arithmetic { } override def relu: SInt = Mux(self >= 0.S, self, 0.S) - override def relu6(shift: UInt): SInt = { - val max6 = (6.S << shift).asSInt() - val maxwidth 
= ((1 << (self.getWidth-1))-1).S - val max = Mux(max6 > maxwidth, maxwidth, max6)(self.getWidth-1, 0).asSInt() - MuxCase(self, Seq((self < 0.S) -> 0.S, (self > max) -> max)) - } override def zero: SInt = 0.S override def identity: SInt = 1.S + override def minimum: SInt = (-(1 << (self.getWidth-1))).S + + override def divider(denom_t: UInt): Option[(DecoupledIO[UInt], DecoupledIO[SInt])] = { + // TODO this uses a floating point divider, but we should use an integer divider instead + + val input = Wire(Decoupled(denom_t.cloneType)) + val output = Wire(Decoupled(self.cloneType)) + + // We translate our integer to floating-point form so that we can use the hardfloat divider + val expWidth = log2Up(self.getWidth) + 1 + val sigWidth = self.getWidth + + def sin_to_float(x: SInt) = { + val in_to_rec_fn = Module(new INToRecFN(intWidth = self.getWidth, expWidth, sigWidth)) + in_to_rec_fn.io.signedIn := true.B + in_to_rec_fn.io.in := x.asUInt() + in_to_rec_fn.io.roundingMode := consts.round_minMag // consts.round_near_maxMag + in_to_rec_fn.io.detectTininess := consts.tininess_afterRounding + + in_to_rec_fn.io.out + } + + def uin_to_float(x: UInt) = { + val in_to_rec_fn = Module(new INToRecFN(intWidth = self.getWidth, expWidth, sigWidth)) + in_to_rec_fn.io.signedIn := false.B + in_to_rec_fn.io.in := x + in_to_rec_fn.io.roundingMode := consts.round_minMag // consts.round_near_maxMag + in_to_rec_fn.io.detectTininess := consts.tininess_afterRounding + + in_to_rec_fn.io.out + } + + def float_to_in(x: UInt) = { + val rec_fn_to_in = Module(new RecFNToIN(expWidth = expWidth, sigWidth, self.getWidth)) + rec_fn_to_in.io.signedOut := true.B + rec_fn_to_in.io.in := x + rec_fn_to_in.io.roundingMode := consts.round_minMag // consts.round_near_maxMag + + rec_fn_to_in.io.out.asSInt() + } + + val self_rec = sin_to_float(self) + val denom_rec = uin_to_float(input.bits) + + // Instantiate the hardloat divider + val divider = Module(new DivSqrtRecFN_small(expWidth, sigWidth, 0)) + + input.ready 
:= divider.io.inReady + divider.io.inValid := input.valid + divider.io.sqrtOp := false.B + divider.io.a := self_rec + divider.io.b := denom_rec + divider.io.roundingMode := consts.round_minMag + divider.io.detectTininess := consts.tininess_afterRounding + + output.valid := divider.io.outValid_div + output.bits := float_to_in(divider.io.out) + + assert(!output.valid || output.ready) + + Some((input, output)) + } + + override def sqrt: Option[(DecoupledIO[UInt], DecoupledIO[SInt])] = { + // TODO this uses a floating point divider, but we should use an integer divider instead + + val input = Wire(Decoupled(UInt(0.W))) + val output = Wire(Decoupled(self.cloneType)) + + input.bits := DontCare + + // We translate our integer to floating-point form so that we can use the hardfloat divider + val expWidth = log2Up(self.getWidth) + 1 + val sigWidth = self.getWidth + + def in_to_float(x: SInt) = { + val in_to_rec_fn = Module(new INToRecFN(intWidth = self.getWidth, expWidth, sigWidth)) + in_to_rec_fn.io.signedIn := true.B + in_to_rec_fn.io.in := x.asUInt() + in_to_rec_fn.io.roundingMode := consts.round_minMag // consts.round_near_maxMag + in_to_rec_fn.io.detectTininess := consts.tininess_afterRounding + + in_to_rec_fn.io.out + } + + def float_to_in(x: UInt) = { + val rec_fn_to_in = Module(new RecFNToIN(expWidth = expWidth, sigWidth, self.getWidth)) + rec_fn_to_in.io.signedOut := true.B + rec_fn_to_in.io.in := x + rec_fn_to_in.io.roundingMode := consts.round_minMag // consts.round_near_maxMag + + rec_fn_to_in.io.out.asSInt() + } + + val self_rec = in_to_float(self) + + // Instantiate the hardloat sqrt + val sqrter = Module(new DivSqrtRecFN_small(expWidth, sigWidth, 0)) + + input.ready := sqrter.io.inReady + sqrter.io.inValid := input.valid + sqrter.io.sqrtOp := true.B + sqrter.io.a := self_rec + sqrter.io.b := DontCare + sqrter.io.roundingMode := consts.round_minMag + sqrter.io.detectTininess := consts.tininess_afterRounding + + output.valid := sqrter.io.outValid_sqrt + 
output.bits := float_to_in(sqrter.io.out) + + assert(!output.valid || output.ready) + + Some((input, output)) + } + + override def reciprocal[U <: Data](u: U): Option[(DecoupledIO[UInt], DecoupledIO[U])] = u match { + case Float(expWidth, sigWidth) => + val input = Wire(Decoupled(UInt(0.W))) + val output = Wire(Decoupled(u.cloneType)) + + input.bits := DontCare + + // We translate our integer to floating-point form so that we can use the hardfloat divider + def in_to_float(x: SInt) = { + val in_to_rec_fn = Module(new INToRecFN(intWidth = self.getWidth, expWidth, sigWidth)) + in_to_rec_fn.io.signedIn := true.B + in_to_rec_fn.io.in := x.asUInt() + in_to_rec_fn.io.roundingMode := consts.round_near_even // consts.round_near_maxMag + in_to_rec_fn.io.detectTininess := consts.tininess_afterRounding + + in_to_rec_fn.io.out + } + + val self_rec = in_to_float(self) + val one_rec = in_to_float(1.S) + + // Instantiate the hardloat divider + val divider = Module(new DivSqrtRecFN_small(expWidth, sigWidth, 0)) + + input.ready := divider.io.inReady + divider.io.inValid := input.valid + divider.io.sqrtOp := false.B + divider.io.a := one_rec + divider.io.b := self_rec + divider.io.roundingMode := consts.round_near_even + divider.io.detectTininess := consts.tininess_afterRounding + + output.valid := divider.io.outValid_div + output.bits := fNFromRecFN(expWidth, sigWidth, divider.io.out).asTypeOf(u) + + assert(!output.valid || output.ready) + + Some((input, output)) + + case _ => None + } + + override def mult_with_reciprocal[U <: Data](reciprocal: U): SInt = reciprocal match { + case recip @ Float(expWidth, sigWidth) => + def in_to_float(x: SInt) = { + val in_to_rec_fn = Module(new INToRecFN(intWidth = self.getWidth, expWidth, sigWidth)) + in_to_rec_fn.io.signedIn := true.B + in_to_rec_fn.io.in := x.asUInt() + in_to_rec_fn.io.roundingMode := consts.round_near_even // consts.round_near_maxMag + in_to_rec_fn.io.detectTininess := consts.tininess_afterRounding + + in_to_rec_fn.io.out + } 
+ + def float_to_in(x: UInt) = { + val rec_fn_to_in = Module(new RecFNToIN(expWidth = expWidth, sigWidth, self.getWidth)) + rec_fn_to_in.io.signedOut := true.B + rec_fn_to_in.io.in := x + rec_fn_to_in.io.roundingMode := consts.round_minMag + + rec_fn_to_in.io.out.asSInt() + } + + val self_rec = in_to_float(self) + val reciprocal_rec = recFNFromFN(expWidth, sigWidth, recip.bits) + + // Instantiate the hardloat divider + val muladder = Module(new MulAddRecFN(expWidth, sigWidth)) + muladder.io.op := 0.U + muladder.io.roundingMode := consts.round_near_even + muladder.io.detectTininess := consts.tininess_afterRounding + + muladder.io.a := self_rec + muladder.io.b := reciprocal_rec + muladder.io.c := 0.U + + float_to_in(muladder.io.out) + + case _ => self + } } } @@ -239,6 +432,12 @@ object Arithmetic { result } + override def -(t: Float): Float = { + val t_sgn = t.bits(t.getWidth-1) + val neg_t = Cat(~t_sgn, t.bits(t.getWidth-2,0)).asTypeOf(t) + self + neg_t + } + override def >>(u: UInt): Float = { // Recode self val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits) @@ -322,55 +521,9 @@ object Arithmetic { result } - override def relu6(shift: UInt): Float = { - // Get a constant 6 as a float - val in_to_rec_fn = Module(new INToRecFN(log2Up(6+1), self.expWidth, self.sigWidth)) - in_to_rec_fn.io.signedIn := false.B - in_to_rec_fn.io.in := 6.U - in_to_rec_fn.io.roundingMode := consts.round_near_even // consts.round_near_maxMag - in_to_rec_fn.io.detectTininess := consts.tininess_afterRounding - - val six_rec = in_to_rec_fn.io.out - - // Get 2^shift as a float - val shift_exp = self.bias.U(self.expWidth.W) + shift - val shift_fn = Cat(0.U(1.W), shift_exp, 0.U((self.sigWidth-1).W)) - val shift_rec = recFNFromFN(self.expWidth, self.sigWidth, shift_fn) - - // Get 6*(2^shift) as a float - val muladder = Module(new MulAddRecFN(self.expWidth, self.sigWidth)) - - muladder.io.op := 0.U - muladder.io.roundingMode := consts.round_near_even // consts.round_near_maxMag - 
muladder.io.detectTininess := consts.tininess_afterRounding - - muladder.io.a := six_rec - muladder.io.b := shift_rec - muladder.io.c := 0.U - - val shifted_rec = muladder.io.out - - // Now, compare self and 6*(2^shift) to calculate the activation function - val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits) - val self_raw = rawFloatFromFN(self.expWidth, self.sigWidth, self.bits) - - val comparer = Module(new CompareRecFN(self.expWidth, self.sigWidth)) - comparer.io.a := self_rec - comparer.io.b := shifted_rec - comparer.io.signaling := false.B - - val larger_than_six = comparer.io.gt - - val result_rec = Mux(!self_raw.isZero && self_raw.sign, 0.U, - Mux(larger_than_six, shifted_rec, self_rec)) - - val result = Wire(Float(self.expWidth, self.sigWidth)) - result.bits := fNFromRecFN(self.expWidth, self.sigWidth, result_rec) - result - } - override def zero: Float = 0.U.asTypeOf(self) override def identity: Float = Cat(0.U(2.W), ~(0.U((self.expWidth-1).W)), 0.U((self.sigWidth-1).W)).asTypeOf(self) + override def minimum: Float = Cat(1.U, ~(0.U(self.expWidth.W)), 0.U((self.sigWidth-1).W)).asTypeOf(self) } } @@ -379,14 +532,15 @@ object Arithmetic { override def *(t: DummySInt) = self.dontCare override def mac(m1: DummySInt, m2: DummySInt) = self.dontCare override def +(t: DummySInt) = self.dontCare + override def -(t: DummySInt) = self.dontCare override def >>(t: UInt) = self.dontCare override def >(t: DummySInt): Bool = false.B override def identity = self.dontCare override def withWidthOf(t: DummySInt) = self.dontCare override def clippedToWidthOf(t: DummySInt) = self.dontCare override def relu = self.dontCare - override def relu6(shift: UInt) = self.dontCare override def zero = self.dontCare + override def minimum: DummySInt = self.dontCare } } } diff --git a/src/main/scala/gemmini/BeatMerger.scala b/src/main/scala/gemmini/BeatMerger.scala index a845327b..a6a67dab 100644 --- a/src/main/scala/gemmini/BeatMerger.scala +++ 
b/src/main/scala/gemmini/BeatMerger.scala @@ -1,3 +1,4 @@ + package gemmini import chisel3._ @@ -26,7 +27,8 @@ class BeatMergerOut(val spadWidth: Int, val accWidth: Int, val spadRows: Int, va maxReqBytes: in bytes aligned_to: in bytes */ -class BeatMerger[U <: Data](beatBits: Int, maxShift: Int, spadWidth: Int, accWidth: Int, spadRows: Int, accRows: Int, maxReqBytes: Int, alignedTo: Int, meshRows: Int, mvin_scale_t_bits: Int, nCmds: Int) +class BeatMerger[U <: Data](beatBits: Int, maxShift: Int, spadWidth: Int, accWidth: Int, spadRows: Int, accRows: Int, + maxReqBytes: Int, alignedTo: Int, meshRows: Int, mvin_scale_t_bits: Int, nCmds: Int) extends Module { val io = IO(new Bundle { val req = Flipped(Decoupled(new XactTrackerEntry(maxShift, spadWidth, accWidth, spadRows, accRows, maxReqBytes, mvin_scale_t_bits, nCmds))) @@ -75,9 +77,10 @@ class BeatMerger[U <: Data](beatBits: Int, maxShift: Int, spadWidth: Int, accWid val total_bytes_sent = req.bits.spad_row_offset + bytesSent Mux(req.bits.has_acc_bitwidth, // We only add "if" statements here to satisfy the Verilator linter. The code would be cleaner without the - // "if" condition and the "else" clause - if (total_bytes_sent.getWidth >= log2Up(accWidthBytes+1)) total_bytes_sent / accWidthBytes.U else 0.U, - if (total_bytes_sent.getWidth >= log2Up(spadWidthBytes+1)) total_bytes_sent / spadWidthBytes.U else 0.U) + // "if" condition and the "else" clause. Similarly, the width expansions are also there to satisfy the Verilator + // linter, despite making the code uglier. 
+ if (total_bytes_sent.getWidth >= log2Up(accWidthBytes + 1)) total_bytes_sent / accWidthBytes.U(total_bytes_sent.getWidth.W) else 0.U, + if (total_bytes_sent.getWidth >= log2Up(spadWidthBytes + 1)) total_bytes_sent / spadWidthBytes.U(total_bytes_sent.getWidth.W) else 0.U) } io.out.bits.is_acc := req.bits.is_acc diff --git a/src/main/scala/gemmini/Configs.scala b/src/main/scala/gemmini/Configs.scala index 7464dc61..bd84b317 100644 --- a/src/main/scala/gemmini/Configs.scala +++ b/src/main/scala/gemmini/Configs.scala @@ -146,7 +146,7 @@ object GemminiConfigs { Mux(overflow, sat, rec_fn_to_in.io.out.asTypeOf(t)) }, - 1, Float(8, 24), -1, + 8, Float(8, 24), -1, identity = "1.0", c_str = "({float y = ROUND_NEAR_EVEN((x) * (scale)); y > INT8_MAX ? INT8_MAX : (y < INT8_MIN ? INT8_MIN : (acc_t)y);})" )), @@ -254,13 +254,11 @@ class DefaultGemminiConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( gemmini } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) // This Gemmini config has both an Int and an FP Gemmini side-by-side, sharing // the same scratchpad. 
class DualGemminiConfig extends Config((site, here, up) => { - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) case BuildRoCC => { var int_gemmini: Gemmini[_,_,_] = null var fp_gemmini: Gemmini[_,_,_] = null diff --git a/src/main/scala/gemmini/ConfigsFP.scala b/src/main/scala/gemmini/ConfigsFP.scala index 740ece36..c76907dd 100644 --- a/src/main/scala/gemmini/ConfigsFP.scala +++ b/src/main/scala/gemmini/ConfigsFP.scala @@ -121,7 +121,6 @@ class GemminiFP32DefaultConfig extends Config((site, here, up) => { LazyModule(new Gemmini(GemminiFPConfigs.FP32DefaultConfig)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) @@ -134,7 +133,6 @@ class GemminiFP16DefaultConfig extends Config((site, here, up) => { LazyModule(new Gemmini(GemminiFPConfigs.FP16DefaultConfig)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) //===========BFLOAT16 Default Config========= @@ -146,7 +144,6 @@ class GemminiBF16DefaultConfig extends Config((site, here, up) => { LazyModule(new Gemmini(GemminiFPConfigs.BF16DefaultConfig)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) class GemminiBF16DefaultHighPerfConfig extends Config((site, here, up) => { @@ -161,7 +158,6 @@ class GemminiBF16DefaultHighPerfConfig extends Config((site, here, up) => { gemmini } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) //===========BFLOAT16 Default Config 8x8========= @@ -173,6 +169,5 @@ class GemminiBF16Default8Config extends Config((site, here, up) => { LazyModule(new Gemmini(GemminiFPConfigs.BF16Default8Config)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) diff --git a/src/main/scala/gemmini/Controller.scala b/src/main/scala/gemmini/Controller.scala index 3ff10955..2c15d3ea 100644 --- a/src/main/scala/gemmini/Controller.scala +++ b/src/main/scala/gemmini/Controller.scala @@ -150,7 +150,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] new ConfigExRs1(acc_scale_t_bits), new 
PreloadRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new PreloadRs(mvout_rows_bits, mvout_cols_bits, local_addr_t), new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), - has_training_convs, has_max_pool, has_first_layer_optimizations) } + has_training_convs, has_max_pool, has_first_layer_optimizations, has_dw_convs) } val (loop_cmd, loop_matmul_unroller_busy) = withClock (gated_clock) { LoopMatmul(conv_cmd, reservation_station.io.matmul_ld_completed, reservation_station.io.matmul_st_completed, reservation_station.io.matmul_ex_completed, meshRows*tileRows, coreMaxAddrBits, reservation_station_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries, @@ -399,7 +399,6 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] unrolled_cmd.ready := true.B } } - } // Debugging signals diff --git a/src/main/scala/gemmini/CounterFile.scala b/src/main/scala/gemmini/CounterFile.scala index 35f50c20..7b28b8e2 100644 --- a/src/main/scala/gemmini/CounterFile.scala +++ b/src/main/scala/gemmini/CounterFile.scala @@ -225,8 +225,8 @@ class CounterController(nPerfCounter: Int, counterWidth: Int)(implicit p: Parame if (nPerfCounter > 0) { val nCounterIndexBit = log2Ceil(nPerfCounter) - val module = Module(new CounterFile(nPerfCounter: Int, counterWidth: Int)) - module.io.event_io <> io.event_io + val counterfile = Module(new CounterFile(nPerfCounter: Int, counterWidth: Int)) + counterfile.io.event_io <> io.event_io val out_reg = Reg(io.out.bits.cloneType) val out_valid_reg = RegInit(false.B) @@ -242,13 +242,13 @@ class CounterController(nPerfCounter: Int, counterWidth: Int)(implicit p: Parame // rs1[31] = External counter flag io.in.ready := !out_valid_reg - module.io.addr := io.in.bits.rs1(nCounterIndexBit + 3, 4) - module.io.counter_reset := io.in.bits.rs1(0) & io.in.fire - module.io.snapshot_reset := io.in.bits.rs1(1) & io.in.fire - module.io.snapshot := 
io.in.bits.rs1(2) & io.in.fire - module.io.config_address.valid := io.in.bits.rs1(3) & io.in.fire - module.io.config_address.bits := io.in.bits.rs1(17, 12) - module.io.external := io.in.bits.rs1(31) + counterfile.io.addr := io.in.bits.rs1(nCounterIndexBit + 3, 4) + counterfile.io.counter_reset := io.in.bits.rs1(0) & io.in.fire + counterfile.io.snapshot_reset := io.in.bits.rs1(1) & io.in.fire + counterfile.io.snapshot := io.in.bits.rs1(2) & io.in.fire + counterfile.io.config_address.valid := io.in.bits.rs1(3) & io.in.fire + counterfile.io.config_address.bits := io.in.bits.rs1(17, 12) + counterfile.io.external := io.in.bits.rs1(31) when (io.out.fire) { out_valid_reg := false.B @@ -256,7 +256,7 @@ class CounterController(nPerfCounter: Int, counterWidth: Int)(implicit p: Parame out_valid_reg := true.B out_reg.rd := io.in.bits.inst.rd out_reg.data := 0.U - out_reg.data := module.io.data + out_reg.data := counterfile.io.data } io.out.valid := out_valid_reg diff --git a/src/main/scala/gemmini/CustomConfigs.scala b/src/main/scala/gemmini/CustomConfigs.scala index e1ed7199..ae529a69 100644 --- a/src/main/scala/gemmini/CustomConfigs.scala +++ b/src/main/scala/gemmini/CustomConfigs.scala @@ -41,6 +41,14 @@ object GemminiCustomConfigs { acc_capacity = CapacityInKilobytes(128), ) + val ibertInferenceConfig = defaultConfig.copy( + has_training_convs = false, + has_max_pool = false, + has_normalizations = true, + + acc_capacity = CapacityInKilobytes(128), + ) + // Specify which of your custom configs you want to build here val customConfig = baselineInferenceConfig } @@ -56,5 +64,5 @@ class GemminiCustomConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( gemmini } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) + diff --git a/src/main/scala/gemmini/CustomSoCConfigs.scala b/src/main/scala/gemmini/CustomSoCConfigs.scala index aebfb520..057aa1e1 100644 --- a/src/main/scala/gemmini/CustomSoCConfigs.scala +++ b/src/main/scala/gemmini/CustomSoCConfigs.scala @@ 
-10,10 +10,10 @@ class CustomGemminiSoCConfig extends Config( new chipyard.config.WithL2TLBs(512) ++ new freechips.rocketchip.subsystem.WithInclusiveCache( - nBanks = 1, nWays = 8, capacityKB = 512, - outerLatencyCycles = 40 + outerLatencyCycles = 40, + subBankingFactor = 4 ) ++ // Set the number of CPUs you want to create diff --git a/src/main/scala/gemmini/DMA.scala b/src/main/scala/gemmini/DMA.scala index 9761228f..71148b67 100644 --- a/src/main/scala/gemmini/DMA.scala +++ b/src/main/scala/gemmini/DMA.scala @@ -1,3 +1,4 @@ + package gemmini import chisel3._ @@ -120,10 +121,11 @@ class StreamReadBeat (val nXacts: Int, val beatBits: Int, val maxReqBytes: Int) } // TODO StreamReaderCore and StreamWriter are actually very alike. Is there some parent class they could both inherit from? -class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V], nXacts: Int, beatBits: Int, maxBytes: Int, - spadWidth: Int, accWidth: Int, aligned_to: Int, - spad_rows: Int, acc_rows: Int, meshRows: Int, use_tlb_register_filter: Boolean, - use_firesim_simulation_counters: Boolean) +class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V], nXacts: Int, beatBits: Int, + maxBytes: Int, spadWidth: Int, accWidth: Int, aligned_to: Int, + spad_rows: Int, acc_rows: Int, meshRows: Int, + use_tlb_register_filter: Boolean, + use_firesim_simulation_counters: Boolean) (implicit p: Parameters) extends LazyModule { val node = TLHelper.makeClientNode( name = "stream-reader", sourceId = IdRange(0, nXacts)) @@ -263,9 +265,10 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf io.reserve.entry.addr := req.spaddr + req.block_stride * Mux(req.has_acc_bitwidth, // We only add "if" statements here to satisfy the Verilator linter. 
The code would be cleaner without the - // "if" condition and the "else" clause - if (bytesRequested.getWidth >= log2Up(accWidthBytes+1)) bytesRequested / accWidthBytes.U else 0.U, - if (bytesRequested.getWidth >= log2Up(spadWidthBytes+1)) bytesRequested / spadWidthBytes.U else 0.U) + // "if" condition and the "else" clause. Similarly, the width expansions are also there to satisfy the Verilator + // linter, despite making the code uglier. + if (bytesRequested.getWidth >= log2Up(accWidthBytes+1)) bytesRequested / accWidthBytes.U(bytesRequested.getWidth.W) else 0.U, + if (bytesRequested.getWidth >= log2Up(spadWidthBytes+1)) bytesRequested / spadWidthBytes.U(bytesRequested.getWidth.W) else 0.U) io.reserve.entry.spad_row_offset := Mux(req.has_acc_bitwidth, bytesRequested % accWidthBytes.U, bytesRequested % spadWidthBytes.U) when (untranslated_a.fire) { @@ -408,7 +411,7 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: val bytes_written = UInt(log2Up(maxBytes+1).W) val bytes_written_per_beat = Vec(maxBeatsPerReq, UInt(log2Up(beatBytes+1).W)) - def total_beats(dummy: Int = 0) = Mux(size < beatBytes.U, 1.U, size / beatBytes.U) + def total_beats(dummy: Int = 0) = Mux(size < beatBytes.U, 1.U, size / beatBytes.U(size.getWidth.W)) // The width expansion is added here solely to satisfy Verilator's linter } val smallest_write_size = aligned_to max beatBytes @@ -460,9 +463,6 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: } val write_packet = RegEnableThru(best_write_packet, state === s_writing_new_block) - for (wp <- write_packets) - dontTouch(wp) - val write_size = write_packet.size val lg_write_size = write_packet.lg_size val write_beats = write_packet.total_beats() diff --git a/src/main/scala/gemmini/DMACommandTracker.scala b/src/main/scala/gemmini/DMACommandTracker.scala index 3390cbdf..9d4f71e6 100644 --- a/src/main/scala/gemmini/DMACommandTracker.scala +++ b/src/main/scala/gemmini/DMACommandTracker.scala
@@ -20,7 +20,6 @@ class DMACommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: => val tag = Input(tag_t.cloneType) val bytes_to_read = Input(UInt(log2Up(maxBytes+1).W)) val cmd_id = Output(cmd_id_t.cloneType) - } val bits = new BitsT(tag_t.cloneType, cmd_id_t.cloneType) diff --git a/src/main/scala/gemmini/DSEConfigs.scala b/src/main/scala/gemmini/DSEConfigs.scala index a34658e3..3ed92c7c 100644 --- a/src/main/scala/gemmini/DSEConfigs.scala +++ b/src/main/scala/gemmini/DSEConfigs.scala @@ -119,7 +119,6 @@ class GemminiParamsDSE1 extends Config((site, here, up) => { LazyModule(new Gemmini(DSEConfigs.baseConfig)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) //===========DATAFLOW CHANGE: WS========= @@ -131,7 +130,6 @@ class GemminiParamsDSE2 extends Config((site, here, up) => { LazyModule(new Gemmini(DSEConfigs.wsOnlyConfig)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) //===========DATAFLOW CHANGE: BOTH========= @@ -143,7 +141,6 @@ class GemminiParamsDSE3 extends Config((site, here, up) => { LazyModule(new Gemmini(DSEConfigs.bothDataflowsConfig)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) //===========BITWIDTH CHANGE: 32 BITS========= @@ -155,7 +152,6 @@ class GemminiParamsDSE4 extends Config((site, here, up) => { LazyModule(new Gemmini(DSEConfigs.highBitwidthConfig)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) //===========DIMENSIONS CHANGE: 32x32========= @@ -167,7 +163,6 @@ class GemminiParamsDSE5 extends Config((site, here, up) => { LazyModule(new Gemmini(DSEConfigs.largerDimConfig)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) //===========PIPELINE DEPTH CHANGE: Fully Combinational========= @@ -179,7 +174,6 @@ class GemminiParamsDSE6 extends Config((site, here, up) => { LazyModule(new Gemmini(DSEConfigs.fullyCombinationalConfig)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) //===========MEMORY CAPACITY 
CHANGE: 256 KB========= @@ -191,7 +185,6 @@ class GemminiParamsDSE7 extends Config((site, here, up) => { LazyModule(new Gemmini(DSEConfigs.moreMemoryConfig)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) //===========MEMORY BANKS CHANGE: 33 Banks========= @@ -203,7 +196,6 @@ class GemminiParamsDSE8 extends Config((site, here, up) => { LazyModule(new Gemmini(DSEConfigs.moreBanksConfig)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) //===========BUS WIDTH CHANGE: 64 bits========= @@ -215,7 +207,6 @@ class GemminiParamsDSE10 extends Config((site, here, up) => { LazyModule(new Gemmini(DSEConfigs.narrowerBusConfig)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 8) }) //===========PnR 16-by-16========= @@ -227,7 +218,6 @@ class GemminiParamsPnR16 extends Config((site, here, up) => { LazyModule(new Gemmini(DSEConfigs.pnr16Config)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) //===========PnR 32-by-32========= @@ -239,7 +229,6 @@ class GemminiParamsPnR32 extends Config((site, here, up) => { LazyModule(new Gemmini(DSEConfigs.pnr32Config)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) //===========Scalar Processor Change========= @@ -251,7 +240,6 @@ class GemminiParamsDSE11 extends Config((site, here, up) => { LazyModule(new Gemmini(DSEConfigs.baseConfig)) } ) - case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) // ----------------------------- diff --git a/src/main/scala/gemmini/ExecuteController.scala b/src/main/scala/gemmini/ExecuteController.scala index 62fc4495..2ef7fa3f 100644 --- a/src/main/scala/gemmini/ExecuteController.scala +++ b/src/main/scala/gemmini/ExecuteController.scala @@ -29,7 +29,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val acc = new Bundle { val read_req = Vec(acc_banks, Decoupled(new AccumulatorReadReq( - acc_bank_entries, log2Up(accType.getWidth), acc_scale_t + acc_bank_entries, accType, 
acc_scale_t ))) val read_resp = Flipped(Vec(acc_banks, Decoupled(new AccumulatorScaleResp( @@ -115,8 +115,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val in_shift = Reg(UInt(log2Up(accType.getWidth).W)) val acc_scale = Reg(acc_scale_t) - val relu6_shift = Reg(UInt(log2Up(accType.getWidth).W)) - val activation = if (has_nonlinear_activations) Reg(UInt(2.W)) else Activation.NONE // TODO magic number + val activation = if (has_nonlinear_activations) Reg(UInt(Activation.bitwidth.W)) else Activation.NONE val a_transpose = Reg(Bool()) val bd_transpose = Reg(Bool()) val config_initialized = RegInit(false.B) @@ -470,7 +469,10 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In io.acc.read_req(i).valid := read_a_from_acc || read_b_from_acc || read_d_from_acc io.acc.read_req(i).bits.scale := acc_scale io.acc.read_req(i).bits.full := false.B - io.acc.read_req(i).bits.relu6_shift := relu6_shift + io.acc.read_req(i).bits.igelu_qb := DontCare + io.acc.read_req(i).bits.igelu_qc := DontCare + io.acc.read_req(i).bits.iexp_qln2 := DontCare + io.acc.read_req(i).bits.iexp_qln2_inv := DontCare io.acc.read_req(i).bits.act := activation io.acc.read_req(i).bits.fromDMA := false.B io.acc.read_req(i).bits.addr := MuxCase(a_address_rs1.acc_row() + a_fire_counter, @@ -487,7 +489,10 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In io.acc.read_req(i).valid := false.B io.acc.read_req(i).bits.scale := DontCare io.acc.read_req(i).bits.full := false.B - io.acc.read_req(i).bits.relu6_shift := relu6_shift + io.acc.read_req(i).bits.igelu_qb := DontCare + io.acc.read_req(i).bits.igelu_qc := DontCare + io.acc.read_req(i).bits.iexp_qln2 := DontCare + io.acc.read_req(i).bits.iexp_qln2_inv := DontCare io.acc.read_req(i).bits.act := DontCare io.acc.read_req(i).bits.fromDMA := false.B io.acc.read_req(i).bits.addr := DontCare @@ -503,8 +508,6 @@ class ExecuteController[T <: Data, U
<: Data, V <: Data](xLen: Int, tagWidth: In when (read_a && !io.im2col.req.ready) { a_ready := false.B } - dontTouch(io.im2col.req.ready) - dontTouch(read_a) io.im2col.req.valid := read_a io.im2col.req.bits.addr := a_address_rs1 @@ -550,7 +553,6 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In } in_shift := config_ex_rs2.in_shift acc_scale := rs1s(0)(xLen - 1, 32).asTypeOf(acc_scale_t) // TODO magic number - relu6_shift := config_ex_rs2.relu6_shift a_transpose := config_ex_rs1.a_transpose bd_transpose := config_ex_rs1.b_transpose @@ -614,7 +616,6 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In start_inputting_a := !a_should_be_fed_into_transposer start_inputting_b := !b_should_be_fed_into_transposer - start_inputting_b := true.B control_state := compute } @@ -924,8 +925,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val activated_wdata = VecInit(mesh.io.resp.bits.data.map(v => VecInit(v.map { e => val e_clipped = e.clippedToWidthOf(inputType) val e_act = MuxCase(e_clipped, Seq( - (activation === Activation.RELU) -> e_clipped.relu, - (activation === Activation.RELU6) -> e_clipped.relu6(relu6_shift))) + (activation === Activation.RELU) -> e_clipped.relu)) e_act }))) @@ -992,7 +992,6 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In when(io.completed.valid) { complete_bits_count := complete_bits_count + 1.U } - dontTouch(complete_bits_count) when (reset.asBool()) { // pending_completed_rob_id.valid := false.B diff --git a/src/main/scala/gemmini/GemminiConfigs.scala b/src/main/scala/gemmini/GemminiConfigs.scala index 97e068c8..573581ec 100644 --- a/src/main/scala/gemmini/GemminiConfigs.scala +++ b/src/main/scala/gemmini/GemminiConfigs.scala @@ -85,7 +85,8 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( has_training_convs: Boolean = true, has_max_pool: Boolean = true, has_nonlinear_activations: 
Boolean = true, - + has_dw_convs: Boolean = true, + has_normalizations: Boolean = false, has_first_layer_optimizations: Boolean = true, use_firesim_simulation_counters: Boolean = false, @@ -491,6 +492,11 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( header ++= "#define HAS_FIRST_LAYER_OPTIMIZATIONS\n\n" } + if (has_normalizations) { + header ++= "#define HAS_NORMALIZATIONS\n" + header ++= "#define NORM_STAT_IDS 4\n\n" + } + header ++= s"#endif // $guard\n" header.toString() } diff --git a/src/main/scala/gemmini/GemminiISA.scala b/src/main/scala/gemmini/GemminiISA.scala index 9cb15ac9..7bca089b 100644 --- a/src/main/scala/gemmini/GemminiISA.scala +++ b/src/main/scala/gemmini/GemminiISA.scala @@ -24,7 +24,7 @@ object GemminiISA { val LOAD3_CMD = 14.U // TODO add orows and ocols to this as well - val LOOP_CONV_WS = 15.U // no_bias, wrot180, trans_output_1203, trans_weight_1203, trans_input_3120, max_pixels_per_row | no_pool, downsample, input_dilated, act + val LOOP_CONV_WS = 15.U // no_bias, wrot180, trans_output_1203, trans_weight_1203, trans_input_3120, dw, max_pixels_per_row | no_pool, downsample, input_dilated, act val LOOP_CONV_WS_CONFIG_1 = 16.U // batch_size, in_dim, in_channels, out_channels | out_dim, pool_out_dim, stride, padding val LOOP_CONV_WS_CONFIG_2 = 17.U // kernel_dim, pool_size, pool_stride, pool_padding | batches, porows, pocols, pochs val LOOP_CONV_WS_CONFIG_3 = 18.U // krows, kcols, kchs, lpad | rpad, upad, dpad, plpad @@ -38,7 +38,7 @@ object GemminiISA { val CONFIG_EX = 0.U val CONFIG_LOAD = 1.U val CONFIG_STORE = 2.U - val CONFIG_IM2COL = 3.U + val CONFIG_NORM = 3.U //========================================================================== // cisc-gemmini opcodes @@ -107,7 +107,7 @@ object GemminiISA { val _unused = UInt(CONFIG_MVIN_RS1_UNUSED_WIDTH.W) } - val CONFIG_MVOUT_RS1_UNUSED_WIDTH = 2 + val CONFIG_MVOUT_RS1_CMD_TYPE_WIDTH = 2 val CONFIG_MVOUT_RS1_ACTIVATION_WIDTH = 2 val 
CONFIG_MVOUT_RS1_MAX_POOLING_STRIDE_WIDTH = 2 val CONFIG_MVOUT_RS1_MAX_POOLING_WINDOW_SIZE_WIDTH = 2 @@ -132,7 +132,7 @@ object GemminiISA { val pool_size = UInt(CONFIG_MVOUT_RS1_MAX_POOLING_WINDOW_SIZE_WIDTH.W) val pool_stride = UInt(CONFIG_MVOUT_RS1_MAX_POOLING_STRIDE_WIDTH.W) val activation = UInt(CONFIG_MVOUT_RS1_ACTIVATION_WIDTH.W) - val _unused = UInt(CONFIG_MVOUT_RS1_UNUSED_WIDTH.W) + val cmd_type = UInt(CONFIG_MVOUT_RS1_CMD_TYPE_WIDTH.W) } val CONFIG_MVOUT_RS2_ACC_SCALE_WIDTH = 32 @@ -145,6 +145,36 @@ object GemminiISA { val stride = UInt(stride_bits.W) } + val CONFIG_NORM_RS1_Q_CONST_WIDTH = 32 + val CONFIG_NORM_RS1_SPACER1_WIDTH = 13 + val CONFIG_NORM_RS1_Q_CONST_TYPE_WIDTH = 1 + val CONFIG_NORM_RS1_SET_STATS_ID_ONLY_WIDTH = 1 + val CONFIG_NORM_RS1_ACT_MSB_WIDTH = 1 + val CONFIG_NORM_RS1_NORM_STATS_ID_WIDTH = 8 + val CONFIG_NORM_RS1_SPACER0_WIDTH = 6 + val CONFIG_NORM_RS1_CMD_TYPE_WIDTH = 2 + + class ConfigNormRs1(acc_t_bits: Int = 32) extends Bundle { + val q_const = UInt(acc_t_bits.W) + val _spacer1 = UInt(CONFIG_NORM_RS1_SPACER1_WIDTH.W) + val q_const_type = UInt(CONFIG_NORM_RS1_Q_CONST_TYPE_WIDTH.W) + val set_stats_id_only = UInt(CONFIG_NORM_RS1_SET_STATS_ID_ONLY_WIDTH.W) + val act_msb = UInt(CONFIG_NORM_RS1_ACT_MSB_WIDTH.W) + val norm_stats_id = UInt(CONFIG_NORM_RS1_NORM_STATS_ID_WIDTH.W) + val _spacer0 = UInt(CONFIG_NORM_RS1_SPACER0_WIDTH.W) + val cmd_type = UInt(CONFIG_NORM_RS1_CMD_TYPE_WIDTH.W) + } + + val CONFIG_NORM_RS2_QC_WIDTH = 32 + val CONFIG_NORM_RS2_QB_WIDTH = 32 + + class ConfigNormRs2(acc_t_bits: Int) extends Bundle { + val _spacer1 = UInt((CONFIG_NORM_RS2_QC_WIDTH - acc_t_bits).W) + val qc = UInt(acc_t_bits.W) + val _spacer0 = UInt((CONFIG_NORM_RS2_QB_WIDTH - acc_t_bits).W) + val qb = UInt(acc_t_bits.W) + } + val CONFIG_EX_RS1_CMD_TYPE_WIDTH = 2 val CONFIG_EX_RS1_DATAFLOW_WIDTH = 1 val CONFIG_EX_RS1_ACTIVATION_WIDTH = 2 diff --git a/src/main/scala/gemmini/LoadController.scala b/src/main/scala/gemmini/LoadController.scala index 
db69857a..71ecf7c7 100644 --- a/src/main/scala/gemmini/LoadController.scala +++ b/src/main/scala/gemmini/LoadController.scala @@ -114,7 +114,7 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig cmd_tracker.io.alloc.valid := control_state === waiting_for_command && cmd.valid && DoLoad cmd_tracker.io.alloc.bits.bytes_to_read := Mux(io.dma.req.bits.has_acc_bitwidth, cols * actual_rows_read * config.accType.getWidth.U, - cols * actual_rows_read * config.inputType.getWidth.U) / 8.U + cols * actual_rows_read * config.inputType.getWidth.U) >> 3 // We replaced a very clear "/ 8.U" operation here with a ">> 3" operation, solely to satisfy Verilator's linter cmd_tracker.io.alloc.bits.tag.rob_id := cmd.bits.rob_id.bits cmd_tracker.io.request_returned.valid := io.dma.resp.fire // TODO use a bundle connect cmd_tracker.io.request_returned.bits.cmd_id := io.dma.resp.bits.cmd_id // TODO use a bundle connect diff --git a/src/main/scala/gemmini/LocalAddr.scala b/src/main/scala/gemmini/LocalAddr.scala index 92e46ffc..b53addea 100644 --- a/src/main/scala/gemmini/LocalAddr.scala +++ b/src/main/scala/gemmini/LocalAddr.scala @@ -21,8 +21,13 @@ class LocalAddr(sp_banks: Int, sp_bank_entries: Int, acc_banks: Int, acc_bank_en val is_acc_addr = Bool() val accumulate = Bool() val read_full_acc_row = Bool() - val garbage = UInt(((localAddrBits - maxAddrBits - 4) max 0).W) - val garbage_bit = if (localAddrBits - maxAddrBits >= 4) UInt(1.W) else UInt(0.W) + val norm_cmd = NormCmd() + + private val metadata_w = is_acc_addr.getWidth + accumulate.getWidth + read_full_acc_row.getWidth + norm_cmd.getWidth + assert(maxAddrBits + metadata_w < 32) + + val garbage = UInt(((localAddrBits - maxAddrBits - metadata_w - 1) max 0).W) + val garbage_bit = if (localAddrBits - maxAddrBits >= metadata_w + 1) UInt(1.W) else UInt(0.W) val data = UInt(maxAddrBits.W) def sp_bank(dummy: Int = 0) = if (spAddrBits == spBankRowBits) 0.U else data(spAddrBits - 1, spBankRowBits) diff --git 
a/src/main/scala/gemmini/LoopConv.scala b/src/main/scala/gemmini/LoopConv.scala index 16609f5a..210bcade 100644 --- a/src/main/scala/gemmini/LoopConv.scala +++ b/src/main/scala/gemmini/LoopConv.scala @@ -1,3 +1,4 @@ + package gemmini import chisel3._ @@ -115,7 +116,7 @@ class LoopConvLdBias(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwi // Addresses val dram_offset = och * (acc_w/8).U val dram_addr = Mux(req.no_bias, 0.U, req.dram_addr + LoopConv.castDramOffset(dram_offset)) - val spad_addr = acc_addr_start +& (och / block_size.U) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol + val spad_addr = acc_addr_start +& (och / block_size.U(och.getWidth.W)) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol // Sizes val I = Mux(ocols - ocol > block_size.U, block_size.U, ocols - ocol) @@ -225,9 +226,10 @@ class LoopConvLdInputReq(val coreMaxAddrBits: Int, val large_iterator_bitwidth: val loop_id = UInt(log2Up(concurrent_loops).W) } -class LoopConvLdInput(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: Int, small_iterator_bitwidth: Int, tiny_iterator_bitwidth: Int, max_addr: Int, input_w: Int, - max_block_len: Int, concurrent_loops: Int, latency: Int, - config_mvin_rs1_t: ConfigMvinRs1, mvin_rs2_t: MvinRs2)(implicit p: Parameters) extends Module { +class LoopConvLdInput(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: Int, small_iterator_bitwidth: Int, + tiny_iterator_bitwidth: Int, max_addr: Int, input_w: Int, max_block_len: Int, + concurrent_loops: Int, latency: Int, config_mvin_rs1_t: ConfigMvinRs1, mvin_rs2_t: MvinRs2) + (implicit p: Parameters) extends Module { val MVIN_SCALE_IDENTITY = 0x3f800000.U // TODO get this from configs somehow val io = IO(new Bundle { @@ -397,12 +399,14 @@ class LoopConvLdWeightReq(val coreMaxAddrBits: Int, val large_iterator_bitwidth: val dram_addr = UInt(coreMaxAddrBits.W) val trans_weight_1203 = Bool() val trans_weight_0132 = Bool() + val dw = Bool() val 
loop_id = UInt(log2Up(concurrent_loops).W) } -class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: Int, small_iterator_bitwidth: Int, tiny_iterator_bitwidth: Int, max_addr: Int, input_w: Int, - max_block_len: Int, concurrent_loops: Int, latency: Int, - config_mvin_rs1_t: ConfigMvinRs1, mvin_rs2_t: MvinRs2)(implicit p: Parameters) extends Module { +class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: Int, + small_iterator_bitwidth: Int, tiny_iterator_bitwidth: Int, max_addr: Int, input_w: Int, + max_block_len: Int, concurrent_loops: Int, latency: Int, config_mvin_rs1_t: ConfigMvinRs1, + mvin_rs2_t: MvinRs2)(implicit p: Parameters) extends Module { val MVIN_SCALE_IDENTITY = 0x3f800000.U // TODO get this from configs somehow val io = IO(new Bundle { @@ -439,6 +443,7 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit val addr_start = req.addr_end - B_rows val dram_stride = MuxCase(out_channels, Seq( + req.dw -> 1.U, req.trans_weight_1203 -> (kernel_dim * kernel_dim * out_channels), req.trans_weight_0132 -> in_channels )) * (input_w/8).U @@ -451,14 +456,16 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit // Addresses val dram_offset = MuxCase(((krow*kernel_dim*in_channels +& kcol*in_channels +& kch) * out_channels +& och) * (input_w/8).U, Seq( + req.dw -> (krow * kernel_dim +& kcol) * (input_w/8).U, req.trans_weight_1203 -> (((kch*kernel_dim*kernel_dim +& krow*kernel_dim +& kcol) * out_channels +& och) * (input_w/8).U), req.trans_weight_0132 -> (((krow*kernel_dim*out_channels +& kcol*out_channels +& och) * in_channels +& kch) * (input_w/8).U) )) val dram_addr = req.dram_addr + LoopConv.castDramOffset(dram_offset) val spad_addr = Mux(req.trans_weight_0132, - addr_start + (kch / block_size.U) * krows * kcols * ochs + krow * kcols * ochs + kcol * ochs + och, - addr_start + (och / block_size.U) * krows * kcols * kchs + krow * kcols * kchs + 
kcol * kchs + kch) + // The width expansions are added here solely to prevent Verilator's "WIDTH" warnings, despite making the code uglier + addr_start + (kch / block_size.U(kch.getWidth.W)) * krows * kcols * ochs + krow * kcols * ochs + kcol * ochs + och, + addr_start + (och / block_size.U(och.getWidth.W)) * krows * kcols * kchs + krow * kcols * kchs + kcol * kchs + kch) // Sizes val J = Mux(req.trans_weight_0132, @@ -643,13 +650,14 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera // Addresses val a_addr = Mux(req.trans_input_3120, a_addr_start +& (b / block_size.U) * input_spad_stride +& kch * (irows >> req.downsample) * (icols >> req.downsample) +& (irow >> req.downsample) * (icols >> req.downsample) +& (icol >> req.downsample), - a_addr_start +& (kch / block_size.U) * input_spad_stride +& b * (irows >> req.downsample) * (icols >> req.downsample) +& (irow >> req.downsample) * (icols >> req.downsample) +& (icol >> req.downsample)) + a_addr_start +& (kch / block_size.U(kch.getWidth.W)) * input_spad_stride +& b * (irows >> req.downsample) * (icols >> req.downsample) +& (irow >> req.downsample) * (icols >> req.downsample) +& (icol >> req.downsample)) // val c_addr = Mux(ex_overwrite && krow === 0.U && kcol === 0.U && kch === 0.U, d_addr_start, c_addr_start) +& // (och / block_size.U) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol + // The width expansions are added here solely to prevent Verilator's "WIDTH" warnings, despite making the code uglier val c_addr = c_addr_start +& - (och / block_size.U) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol + (och / block_size.U(och.getWidth.W)) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol // val new_weights = b === 0.U && orow === 0.U && ocol === 0.U val new_weights = Reg(Bool()) @@ -657,8 +665,8 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera val kcol_ = Mux(req.wrot180, kcols - kcol 
- 1.U, kcol) val b_addr = Mux(req.trans_weight_0132, - b_addr_start +& (kch / block_size.U) * krows * kcols * ochs +& krow_ * kcols * ochs +& kcol_ * ochs +& och, - b_addr_start +& (och / block_size.U) * krows * kcols * kchs +& krow_ * kcols * kchs +& kcol_ * kchs +& kch) + b_addr_start +& (kch / block_size.U(kch.getWidth.W)) * krows * kcols * ochs +& krow_ * kcols * ochs +& kcol_ * ochs +& och, + b_addr_start +& (och / block_size.U(och.getWidth.W)) * krows * kcols * kchs +& krow_ * kcols * kchs +& kcol_ * kchs +& kch) class RoCCCommandWithAddr extends Bundle { val cmd = new RoCCCommand @@ -874,10 +882,10 @@ class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: ((orow*out_dim*batch_size +& ocol*batch_size +& b) * out_channels +& och) * (input_w/8).U, ((b*out_dim*out_dim +& orow*out_dim +& ocol) * out_channels +& och) * (input_w/8).U) val dram_addr = req.dram_addr + LoopConv.castDramOffset(dram_offset) - val spad_addr = acc_addr_start +& (och / block_size.U) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol + val spad_addr = acc_addr_start +& (och / block_size.U(och.getWidth.W)) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol val pool_dram_addr = req.dram_addr + ((b * pool_out_dim * pool_out_dim) * out_channels + och) * (input_w/8).U - val pool_spad_addr = acc_addr_start +& (och / block_size.U) * batches * orows * ocols +& b * orows * ocols + val pool_spad_addr = acc_addr_start +& (och / block_size.U(och.getWidth.W)) * batches * orows * ocols +& b * orows * ocols // Sizes val I = Mux(ocols - ocol > block_size.U, block_size.U, ocols - ocol) @@ -919,7 +927,7 @@ class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: pre_pool_config_cmd_rs1.pool_size := pool_size pre_pool_config_cmd_rs1.pool_stride := pool_stride pre_pool_config_cmd_rs1.activation := req.activation - pre_pool_config_cmd_rs1._unused := CONFIG_STORE + pre_pool_config_cmd_rs1.cmd_type := CONFIG_STORE
pre_pool_config_cmd.rs1 := pre_pool_config_cmd_rs1.asUInt() val pre_pool_config_cmd_rs2 = Wire(config_mvout_rs2_t.cloneType) @@ -935,7 +943,7 @@ class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: val post_pool_config_cmd_rs1 = Wire(new ConfigMvoutRs1) post_pool_config_cmd_rs1 := DontCare post_pool_config_cmd_rs1.activation := req.activation - post_pool_config_cmd_rs1._unused := CONFIG_STORE + post_pool_config_cmd_rs1.cmd_type := CONFIG_STORE post_pool_config_cmd.rs1 := post_pool_config_cmd_rs1.asUInt() val post_pool_config_cmd_rs2 = Wire(config_mvout_rs2_t.cloneType) @@ -1059,6 +1067,7 @@ class LoopConvState(val block_size: Int, val large_iterator_bitwidth: Int, val s val trans_weight_1203 = Bool() val trans_weight_0132 = Bool() val trans_input_3120 = Bool() + val dw = Bool() val max_pixels_per_row = UInt(small_iterator_bitwidth.W) @@ -1112,8 +1121,8 @@ class LoopConvState(val block_size: Int, val large_iterator_bitwidth: Int, val s result.ichs := kchs - result.out_channels_per_bank := result.ochs / block_size.U +& (result.ochs % block_size.U =/= 0.U) - result.in_channels_per_bank := result.ichs / block_size.U +& (result.ichs % block_size.U =/= 0.U) + result.out_channels_per_bank := result.ochs / block_size.U(result.ochs.getWidth.W) +& (result.ochs % block_size.U =/= 0.U) + result.in_channels_per_bank := result.ichs / block_size.U(result.ichs.getWidth.W) +& (result.ichs % block_size.U =/= 0.U) result.bias_spad_stride := batches * orows * ocols result.input_spad_stride := Mux(trans_input_3120, @@ -1150,7 +1159,8 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size: config_mvin_rs1_t: ConfigMvinRs1, mvin_rs2_t: MvinRs2, config_mvout_rs2_t: ConfigMvoutRs2, mvout_rs2_t: MvoutRs2, config_ex_rs1_t: ConfigExRs1, preload_rs1_t: PreloadRs, preload_rs2_t: PreloadRs, compute_rs1_t: ComputeRs, compute_rs2_t: ComputeRs, - has_training_convs: Boolean, has_max_pool: Boolean, has_first_layer_optimizations: Boolean) +
has_training_convs: Boolean, has_max_pool: Boolean, has_first_layer_optimizations: Boolean, + has_dw_convs: Boolean) (implicit p: Parameters) extends Module { val large_iterator_bitwidth = 16 val small_iterator_bitwidth = 16 // 8 @@ -1330,6 +1340,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size: loop_being_configured.trans_weight_1203 := has_training_convs.B && cmd.bits.cmd.rs1(3) loop_being_configured.trans_weight_0132 := has_training_convs.B && cmd.bits.cmd.rs1(4) loop_being_configured.trans_input_3120 := has_training_convs.B && cmd.bits.cmd.rs1(5) + loop_being_configured.dw := has_dw_convs.B && cmd.bits.cmd.rs1(6) loop_being_configured.no_pool := !has_max_pool.B || cmd.bits.cmd.rs2(0) loop_being_configured.activation := cmd.bits.cmd.rs2(4,3) @@ -1400,6 +1411,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size: ld_weights.io.req.bits.dram_addr := loop_requesting_ld_weights.weights_dram_addr ld_weights.io.req.bits.trans_weight_1203 := loop_requesting_ld_weights.trans_weight_1203 ld_weights.io.req.bits.trans_weight_0132 := loop_requesting_ld_weights.trans_weight_0132 + ld_weights.io.req.bits.dw := loop_requesting_ld_weights.dw ld_weights.io.req.bits.loop_id := loop_requesting_ld_weights_id ld_weights.io.req.valid := !loop_requesting_ld_weights.ld_weights_started && loop_requesting_ld_weights.configured @@ -1503,13 +1515,13 @@ object LoopConv { config_mvin_rs1_t: ConfigMvinRs1, mvin_rs2_t: MvinRs2, config_mvout_rs2_t: ConfigMvoutRs2, mvout_rs2_t: MvoutRs2, config_ex_rs1_t: ConfigExRs1, preload_rs1_t: PreloadRs, preload_rs2_t: PreloadRs, compute_rs1_t: ComputeRs, compute_rs2_t: ComputeRs, has_training_convs: Boolean, has_max_pool: Boolean, - has_first_layer_optimizations: Boolean) + has_first_layer_optimizations: Boolean, has_dw_convs: Boolean) (implicit p: Parameters): (DecoupledIO[GemminiCmd], Bool) = { val mod = Module(new LoopConv(block_size, coreMaxAddrBits, rob_size, max_lds, max_exs, max_sts, 
max_addr, max_acc_addr, input_w, acc_w, dma_max_bytes, config_mvin_rs1_t, mvin_rs2_t, config_mvout_rs2_t, mvout_rs2_t, config_ex_rs1_t, preload_rs1_t, preload_rs2_t, - compute_rs1_t, compute_rs2_t, has_training_convs, has_max_pool, has_first_layer_optimizations)) + compute_rs1_t, compute_rs2_t, has_training_convs, has_max_pool, has_first_layer_optimizations, has_dw_convs)) mod.io.in <> in mod.io.ld_completed := ld_completed diff --git a/src/main/scala/gemmini/LoopMatmul.scala b/src/main/scala/gemmini/LoopMatmul.scala index 52871276..86552d56 100644 --- a/src/main/scala/gemmini/LoopMatmul.scala +++ b/src/main/scala/gemmini/LoopMatmul.scala @@ -1,3 +1,4 @@ + package gemmini import chisel3._ @@ -488,6 +489,7 @@ class LoopMatmulStCReq(val block_size: Int, val coreMaxAddrBits: Int, val iterat val dram_addr = UInt(coreMaxAddrBits.W) val dram_stride = UInt(coreMaxAddrBits.W) val full_c = Bool() + val act = UInt(Activation.bitwidth.W) val addr_start = UInt(log2Up(max_acc_addr).W) val loop_id = UInt(log2Up(concurrent_loops).W) } @@ -513,7 +515,7 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In }) object State extends ChiselEnum { - val idle, st = Value + val idle, st, ln_config, ln_st = Value } import State._ val state = RegInit(idle) @@ -522,6 +524,7 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In val max_blocks = Mux(req.full_c, 1.U, Mux(req.max_j <= max_block_len.U, req.max_j, max_block_len.U)) + // Non-normalization-related iterators and calculations val j = Reg(UInt(iterator_bitwidth.W)) val i = Reg(UInt(iterator_bitwidth.W)) @@ -547,26 +550,80 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In mvout_cmd_rs2.local_addr := cast_to_acc_addr(mvout_cmd_rs2.local_addr, sp_addr, accumulate = false.B, read_full = req.full_c) mvout_cmd.rs2 := mvout_cmd_rs2.asUInt() + // Layernorm iterators and calculations + val ln_row = Reg(UInt(iterator_bitwidth.W)) + val ln_cmd = 
Reg(UInt(iterator_bitwidth.W)) + val ln_stat_id = Reg(UInt(iterator_bitwidth.W)) + + val NORM_STAT_IDS = 4 // TODO magic number + + val ln_norm_cmds = VecInit(VecInit(NormCmd.SUM, NormCmd.MEAN), VecInit(NormCmd.VARIANCE, NormCmd.INV_STDDEV), + VecInit(NormCmd.RESET, NormCmd.RESET)) + + val sm_norm_cmds = VecInit(VecInit(NormCmd.MAX, NormCmd.MAX), VecInit(NormCmd.SUM_EXP, NormCmd.INV_SUM_EXP), + VecInit(NormCmd.RESET, NormCmd.RESET)) + + val ln_stat_ids = Mux(rows -& ln_row > NORM_STAT_IDS.U, NORM_STAT_IDS.U, rows -& ln_row) + + val ln_r = ln_row +& ln_stat_id + + val ln_sp_addr = acc_addr_start +& (i * req.max_j +& j) * block_size.U +& ln_r + val ln_norm_cmd = Mux(j +& max_blocks >= req.max_j, + Mux(req.act === Activation.LAYERNORM, ln_norm_cmds(ln_cmd)(1), sm_norm_cmds(ln_cmd)(1)), + Mux(req.act === Activation.LAYERNORM, ln_norm_cmds(ln_cmd)(0), sm_norm_cmds(ln_cmd)(0))) + + // TODO we assume for now that full_C and layernorm aren't true at the same + val ln_dram_offset = ((i * req.dram_stride +& j) * block_size.U +& ln_r * req.dram_stride) * (input_w/8).U + val ln_dram_addr = req.dram_addr + LoopMatmul.castDramOffset(ln_dram_offset) + + val ln_config_norm_rs1 = Wire(new GemminiISA.ConfigNormRs1) + ln_config_norm_rs1 := DontCare + ln_config_norm_rs1.set_stats_id_only := 1.U + ln_config_norm_rs1.cmd_type := CONFIG_NORM + ln_config_norm_rs1.norm_stats_id := ln_stat_id + + val ln_config_norm = Wire(new RoCCCommand) + ln_config_norm := DontCare + ln_config_norm.inst.funct := CONFIG_CMD + ln_config_norm.rs1 := ln_config_norm_rs1.asUInt() + ln_config_norm.rs2 := DontCare + + val ln_mvout_cmd = Wire(new RoCCCommand) + ln_mvout_cmd := DontCare + ln_mvout_cmd.inst.funct := STORE_CMD + ln_mvout_cmd.rs1 := ln_dram_addr + + val ln_mvout_cmd_rs2 = Wire(mvout_rs2_t.cloneType) + ln_mvout_cmd_rs2 := DontCare + ln_mvout_cmd_rs2.num_rows := 1.U + ln_mvout_cmd_rs2.num_cols := cols.asUInt() + ln_mvout_cmd_rs2.local_addr := cast_to_acc_addr(ln_mvout_cmd_rs2.local_addr, ln_sp_addr, 
accumulate = false.B, read_full = req.full_c) + ln_mvout_cmd_rs2.local_addr.norm_cmd := ln_norm_cmd + ln_mvout_cmd.rs2 := ln_mvout_cmd_rs2.asUInt() + io.req.ready := state === idle io.j := j io.i := i io.idle := state === idle - // The order here is k, j, i - // val ex_ahead = io.ex_completed || (io.ex_k === req.max_k - 1.U && (io.ex_j > j || (io.ex_j === j && io.ex_i > i))) + // The order here is k, j, i when not doing LAYERNORM or SOFTMAX val ex_ahead = io.ex_completed || - (io.ex_k === req.max_k - 1.U && - (io.ex_j >= j + blocks || - ((io.ex_j === j + blocks - 1.U) && io.ex_i > i))) + ((req.act =/= Activation.LAYERNORM) && (req.act =/= Activation.SOFTMAX) && + (io.ex_k === req.max_k - 1.U && + (io.ex_j >= j + blocks || + ((io.ex_j === j + blocks - 1.U) && io.ex_i > i)))) io.cmd.valid := state =/= idle && !io.rob_overloaded && ex_ahead && req.dram_addr =/= 0.U - io.cmd.bits := mvout_cmd + io.cmd.bits := MuxCase(mvout_cmd, Seq( + (state === ln_config) -> ln_config_norm, + (state === ln_st) -> ln_mvout_cmd, + )) io.loop_id := req.loop_id when (req.dram_addr === 0.U) { state := idle - }.elsewhen (io.cmd.fire) { + }.elsewhen (io.cmd.fire() && state === st) { // The order here is k, j, i val next_i = floorAdd(i, 1.U, req.max_i) val next_j = floorAdd(j, max_blocks, req.max_j, next_i === 0.U) @@ -577,13 +634,38 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In when (next_i === 0.U && next_j === 0.U) { state := idle } + }.elsewhen (io.cmd.fire() && state === ln_config) { + state := ln_st + }.elsewhen (io.cmd.fire() && state === ln_st) { + val next_j = floorAdd(j, max_blocks, req.max_j) + val next_stat_id = floorAdd(ln_stat_id, 1.U, ln_stat_ids, next_j === 0.U) + val next_cmd = floorAdd(ln_cmd, 1.U, ln_norm_cmds.size.U, next_j === 0.U && next_stat_id === 0.U) + val next_row = floorAdd(ln_row, NORM_STAT_IDS.U, rows, next_j === 0.U && next_stat_id === 0.U && next_cmd === 0.U) + val next_i = floorAdd(i, 1.U, req.max_i, + next_j === 0.U && 
next_stat_id === 0.U && next_cmd === 0.U && next_row === 0.U) + + j := next_j + ln_stat_id := next_stat_id + ln_cmd := next_cmd + ln_row := next_row + i := next_i + + when (next_i === 0.U && next_row === 0.U && next_cmd === 0.U && next_stat_id === 0.U && next_j === 0.U) { + state := idle + }.elsewhen (next_j === 0.U) { + state := ln_config + } } when (io.req.fire) { req := io.req.bits - state := st + state := Mux((io.req.bits.act === Activation.LAYERNORM) || (io.req.bits.act === Activation.SOFTMAX), ln_config, st) + j := 0.U i := 0.U + ln_row := 0.U + ln_cmd := 0.U + ln_stat_id := 0.U } } @@ -610,12 +692,12 @@ class LoopMatmulState(val iterator_bitwidth: Int, val coreMaxAddrBits: Int, val val a_transpose = Bool() val b_transpose = Bool() + val act = UInt(Activation.bitwidth.W) + val low_d = Bool() val full_c = Bool() val ex_accumulate = Bool() - val weightA = UInt(8.W) // TODO magic numbers - val configured = Bool() val running = Bool() @@ -706,7 +788,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size val ab_loads_on_same_loop = ldA.io.loop_id === ldB.io.loop_id ldab_arb.io.forceA := !ab_loads_on_same_loop && ldA.io.loop_id === head_loop_id ldab_arb.io.forceB := !ab_loads_on_same_loop && ldB.io.loop_id === head_loop_id - ldab_arb.io.weightA := head_loop.weightA + ldab_arb.io.weightA := 0.U ldab_arb.io.inA_idle := ldA.io.idle ldab_arb.io.inB_idle := ldB.io.idle ldab_arb.io.inA_k := ldA.io.k @@ -812,11 +894,11 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size loop_being_configured.ex_accumulate := cmd.bits.cmd.rs1(0) loop_being_configured.full_c := cmd.bits.cmd.rs1(1) loop_being_configured.low_d := cmd.bits.cmd.rs1(2) + loop_being_configured.act := cmd.bits.cmd.rs1(8+Activation.bitwidth-1, 8) // TODO magic numbers + loop_being_configured.a_transpose := cmd.bits.cmd.rs2(0) loop_being_configured.b_transpose := cmd.bits.cmd.rs2(1) - loop_being_configured.weightA := cmd.bits.cmd.rs1(15, 8) // TODO magic 
numbers - loop_being_configured.configured := true.B loops_configured := loops_configured + 1.U @@ -928,6 +1010,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size stC.io.req.bits.dram_addr := loop_requesting_st.c_dram_addr stC.io.req.bits.dram_stride := loop_requesting_st.c_dram_stride stC.io.req.bits.full_c := loop_requesting_st.full_c + stC.io.req.bits.act := loop_requesting_st.act stC.io.req.bits.addr_start := st_c_addr_start stC.io.req.bits.loop_id := loop_requesting_st_id diff --git a/src/main/scala/gemmini/MeshWithDelays.scala b/src/main/scala/gemmini/MeshWithDelays.scala index f6cf7517..d0aced16 100644 --- a/src/main/scala/gemmini/MeshWithDelays.scala +++ b/src/main/scala/gemmini/MeshWithDelays.scala @@ -232,8 +232,6 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data] val out_matmul_id = WireInit(shifted(mesh.io.out_id, outBanks, reverse = true)(0)(0)) io.resp.bits.tag := Mux(tagq.io.deq.valid && out_matmul_id === tagq.io.deq.bits.id, tagq.io.deq.bits.tag, tag_garbage) - dontTouch(out_matmul_id) - tagq.io.deq.ready := io.resp.valid && io.resp.bits.last && out_matmul_id === tagq.io.deq.bits.id val total_rows_q = Module(new Queue(new TagWithIdAndTotalRows, tagqlen)) diff --git a/src/main/scala/gemmini/NormCmd.scala b/src/main/scala/gemmini/NormCmd.scala new file mode 100644 index 00000000..515fabb0 --- /dev/null +++ b/src/main/scala/gemmini/NormCmd.scala @@ -0,0 +1,23 @@ + +package gemmini + +import chisel3._ +import chisel3.util._ +import chisel3.experimental.ChiselEnum + +object NormCmd extends ChiselEnum { + val RESET, SUM, MEAN, VARIANCE, INV_STDDEV, MAX, SUM_EXP, INV_SUM_EXP = Value + + def writes_to_main_memory(cmd: Type): Bool = { + cmd === RESET + } + + def non_reset_version(cmd: Type): Type = { + MuxCase(cmd, Seq( + (cmd === MEAN) -> SUM, + (cmd === MAX) -> MAX, + (cmd === INV_STDDEV) -> VARIANCE, + (cmd === INV_SUM_EXP) -> SUM_EXP + )) + } +} diff --git 
a/src/main/scala/gemmini/Normalizer.scala b/src/main/scala/gemmini/Normalizer.scala new file mode 100644 index 00000000..89dca2db --- /dev/null +++ b/src/main/scala/gemmini/Normalizer.scala @@ -0,0 +1,635 @@ + +package gemmini + +import chisel3._ +import chisel3.experimental.ChiselEnum +import chisel3.util._ +import gemmini.AccumulatorScale.iexp +import hardfloat.{DivSqrtRecFN_small, INToRecFN, consts, fNFromRecFN} + +class NormalizedInput[T <: Data: Arithmetic, U <: Data](max_len: Int, num_stats: Int, fullDataType: Vec[Vec[T]], + scale_t: U) extends Bundle { + val acc_read_resp = new AccumulatorReadResp[T,U](fullDataType, scale_t) + val len = UInt(log2Up(max_len + 1).W) + val stats_id = UInt(log2Up(num_stats).W) + val cmd = NormCmd() +} + +class NormalizedOutput[T <: Data: Arithmetic, U <: Data](fullDataType: Vec[Vec[T]], scale_t: U) extends Bundle { + val acc_read_resp = new AccumulatorReadResp[T,U](fullDataType, scale_t) + val mean = fullDataType.head.head.cloneType + val max = fullDataType.head.head.cloneType + val inv_stddev = scale_t.cloneType + val inv_sum_exp = scale_t.cloneType +} + +class IExpConst[T <: Data](acc_t: T) extends Bundle { + val qb = acc_t.cloneType + val qc = acc_t.cloneType + val qln2 = acc_t.cloneType + val qln2_inv = acc_t.cloneType +} + +class AccumulationLanes[T <: Data](num_stats: Int, acc_t: T, n_lanes: Int, latency: Int)(implicit ev: Arithmetic[T]) + extends Module { + // Each lane computes a sum, or an error-squared sum + + import ev._ + + class LaneOutput extends Bundle { + val result = acc_t.cloneType + val stats_id = UInt(log2Up(num_stats).W) + } + + val io = IO(new Bundle { + val ins = Flipped(Valid(new Bundle { + val len = UInt(log2Up(n_lanes+1).W) + val data = Vec(n_lanes, acc_t) + val mean = acc_t.cloneType + val max = acc_t.cloneType + val iexp_const = new IExpConst(acc_t) + val cmd = NormCmd() + val stats_id = UInt(log2Up(num_stats).W) + })) + + val out = Valid(new LaneOutput) + + val busy = Output(Bool()) + }) + + val cmd 
= io.ins.bits.cmd + val mean = io.ins.bits.mean + val iexp_c = io.ins.bits.iexp_const + + val data = io.ins.bits.data.zipWithIndex.map { case (d, i) => + val iexp_result = iexp(d - io.ins.bits.max, iexp_c.qln2, iexp_c.qln2_inv, iexp_c.qb, iexp_c.qc) + Mux(i.U < io.ins.bits.len, + MuxCase(d, Seq( + (cmd === NormCmd.VARIANCE || cmd === NormCmd.INV_STDDEV) -> (d-mean)*(d-mean), + (cmd === NormCmd.SUM_EXP || cmd === NormCmd.INV_SUM_EXP) -> + iexp_result //iexp(d - io.ins.bits.max, iexp_c.qln2, iexp_c.qln2_inv, iexp_c.qb, iexp_c.qc) + )).withWidthOf(acc_t), + d.zero) + } + + val result = data.reduce(_ + _) + + val pipe = Module(new Pipeline[LaneOutput](new LaneOutput, latency)()) + + pipe.io.in.valid := io.ins.valid + // io.ins.ready := pipe.io.in.ready + pipe.io.in.bits.result := result + pipe.io.in.bits.stats_id := io.ins.bits.stats_id + + io.out.valid := pipe.io.out.valid + pipe.io.out.ready := true.B + // pipe.io.out.ready := io.out.ready + io.out.bits := pipe.io.out.bits + + io.busy := pipe.io.busy +} + +class MaxLanes[T <: Data](num_stats: Int, acc_t: T, n_lanes: Int, latency: Int)(implicit ev: Arithmetic[T]) + extends Module { + // Each lane computes a sum, or an error-squared sum + + import ev._ + import NormCmd._ + + class LaneOutput extends Bundle { + val result = acc_t.cloneType + val stats_id = UInt(log2Up(num_stats).W) + } + + val io = IO(new Bundle { + val ins = Flipped(Valid(new Bundle { + val len = UInt(log2Up(n_lanes + 1).W) + val data = Vec(n_lanes, acc_t) + val stats_id = UInt(log2Up(num_stats).W) + })) + + val out = Valid(new LaneOutput) + + val busy = Output(Bool()) + }) + + val data = io.ins.bits.data.zipWithIndex.map { case (d, i) => + Mux(i.U < io.ins.bits.len, d.withWidthOf(acc_t), d.minimum) + } + + val result = data.reduce({ (max, x) => Mux(x > max, x, max) }) + + val pipe = Module(new Pipeline[LaneOutput](new LaneOutput, latency)()) + + pipe.io.in.valid := io.ins.valid + // io.ins.ready := pipe.io.in.ready + pipe.io.in.bits.result := result + 
pipe.io.in.bits.stats_id := io.ins.bits.stats_id + + io.out.valid := pipe.io.out.valid + pipe.io.out.ready := true.B + // pipe.io.out.ready := io.out.ready + io.out.bits := pipe.io.out.bits + + io.busy := pipe.io.busy +} + +class Normalizer[T <: Data, U <: Data](max_len: Int, num_reduce_lanes: Int, num_stats: Int, latency: Int, + fullDataType: Vec[Vec[T]], scale_t: U) + (implicit ev: Arithmetic[T]) extends Module { + import ev._ + val acc_t = fullDataType.head.head.cloneType + val vec_size = fullDataType.flatten.size + val n_lanes = if (num_reduce_lanes < 0) vec_size else num_reduce_lanes + + assert(isPow2(n_lanes)) + + val io = IO(new Bundle { + val in = Flipped(Decoupled(new NormalizedInput[T,U](max_len, num_stats, fullDataType, scale_t))) + val out = Decoupled(new NormalizedOutput(fullDataType, scale_t)) + }) + + object State extends ChiselEnum { + // NOTE: We assume that "idle" and "output" are the first two states. We also assume that all the enums on the same + // line keep the order below + val idle, output = Value + val get_sum = Value + val get_mean, waiting_for_mean = Value + val get_variance, waiting_for_variance, get_stddev, waiting_for_stddev, get_inv_stddev, waiting_for_inv_stddev = Value + val get_max = Value + val get_inv_sum_exp, waiting_for_inv_sum_exp = Value + } + import State._ + + // Buffers for normalization stats + class Stats extends Bundle { + val req = new NormalizedInput[T,U](max_len, num_stats, fullDataType, scale_t) + val state = State() + + // Running state + val sum = acc_t.cloneType + val count = UInt(16.W) // TODO magic number + val running_max = acc_t.cloneType + val max = acc_t.cloneType + + // Iterative state + val mean = acc_t.cloneType + val inv_stddev = acc_t.cloneType + val inv_sum_exp = acc_t.cloneType + + val elems_left = req.len.cloneType + + def vec_grouped = VecInit(req.acc_read_resp.data.flatten.grouped(n_lanes).map(v => VecInit(v)).toSeq) + def vec_groups_left = elems_left / n_lanes.U + (elems_left % n_lanes.U =/= 
0.U) + + def cmd = req.cmd + + def waiting_for_lanes_to_drain = + (cmd === NormCmd.MEAN && (state === get_sum || state === get_mean)) || + (cmd === NormCmd.INV_STDDEV && (state === get_sum || state === get_variance)) || + (cmd === NormCmd.MAX && (state === get_max)) || + (cmd === NormCmd.INV_SUM_EXP && (state === get_sum)) + } + + val stats = Reg(Vec(num_stats, new Stats)) + val done_with_functional_units = Wire(Vec(num_stats, Bool())) + val next_states = Wire(Vec(num_stats, State())) + + (stats.map(_.state) zip next_states).foreach { case (s, ns) => s := ns } + + // IO + val in_stats_id = io.in.bits.stats_id + io.in.ready := (stats(in_stats_id).state === idle || done_with_functional_units(in_stats_id)) && + stats.map(!_.waiting_for_lanes_to_drain).reduce(_ && _) + + val out_stats_id = MuxCase((num_stats-1).U, + stats.zipWithIndex.map { case (s,i) => (s.state === output) -> i.U } + ) + + io.out.valid := stats(out_stats_id).state === output + io.out.bits.acc_read_resp := stats(out_stats_id).req.acc_read_resp + io.out.bits.mean := stats(out_stats_id).mean + io.out.bits.max := stats(out_stats_id).max + io.out.bits.inv_stddev := stats(out_stats_id).inv_stddev.asTypeOf(scale_t) + io.out.bits.inv_sum_exp := stats(out_stats_id).inv_sum_exp.asTypeOf(scale_t) + + // Lanes and functional units + val lanes = Module(new AccumulationLanes(num_stats, acc_t, n_lanes, latency)) + val max_lanes = Module(new MaxLanes(num_stats, acc_t, n_lanes, latency)) // TODO: change latency? 
+ + { + // Lanes input + val in_lanes_stats_id = MuxCase((num_stats-1).U, + stats.zipWithIndex.map { case (s,i) => (s.state === get_sum) -> i.U } + ) + + val stat = stats(in_lanes_stats_id) + + val len = Mux(stat.elems_left % n_lanes.U === 0.U, n_lanes.U, stat.elems_left % n_lanes.U) + + lanes.io.ins.valid := stat.state === get_sum && stat.vec_groups_left > 0.U + lanes.io.ins.bits.data := stat.vec_grouped(stat.vec_groups_left-1.U) + lanes.io.ins.bits.mean := stat.mean + lanes.io.ins.bits.max := stat.max + + val iexp_const = Wire(new IExpConst(acc_t)) + iexp_const.qln2 := io.in.bits.acc_read_resp.iexp_qln2.asTypeOf(iexp_const.qln2) + iexp_const.qln2_inv := io.in.bits.acc_read_resp.iexp_qln2_inv.asTypeOf(iexp_const.qln2_inv) + iexp_const.qb := io.in.bits.acc_read_resp.igelu_qb.asTypeOf(iexp_const.qb) + iexp_const.qc := io.in.bits.acc_read_resp.igelu_qc.asTypeOf(iexp_const.qc) + + lanes.io.ins.bits.cmd := stat.cmd + lanes.io.ins.bits.len := len + lanes.io.ins.bits.stats_id := in_lanes_stats_id + lanes.io.ins.bits.iexp_const := iexp_const + + when (lanes.io.ins.fire()) { + stat.elems_left := stat.elems_left - len + } + } + + { + // Lanes output + val out_lanes_stats_id = lanes.io.out.bits.stats_id + + val stat = stats(out_lanes_stats_id) + + when (lanes.io.out.fire()) { + stat.sum := stat.sum + lanes.io.out.bits.result + } + } + + { + // Max lanes input + val max_in_lanes_stats_id = MuxCase((num_stats-1).U, + stats.zipWithIndex.map { case (s,i) => (s.state === get_max) -> i.U } + ) + + val stat = stats(max_in_lanes_stats_id) + + val len = Mux(stat.elems_left % n_lanes.U === 0.U, n_lanes.U, stat.elems_left % n_lanes.U) + + max_lanes.io.ins.valid := stat.state === get_max && stat.vec_groups_left > 0.U + max_lanes.io.ins.bits.data := stat.vec_grouped(stat.vec_groups_left-1.U) + max_lanes.io.ins.bits.len := len + max_lanes.io.ins.bits.stats_id := max_in_lanes_stats_id + + when (max_lanes.io.ins.fire()) { + stat.elems_left := stat.elems_left - len + } + } + + { + // Max 
lanes output + val max_out_lanes_stats_id = max_lanes.io.out.bits.stats_id + + val stat = stats(max_out_lanes_stats_id) + + when (max_lanes.io.out.fire()) { + stat.running_max := Mux(max_lanes.io.out.bits.result > stat.running_max, max_lanes.io.out.bits.result, stat.running_max) + //stat.max := Mux(max_lanes.io.out.bits.result > stat.max, max_lanes.io.out.bits.result, stat.max) + } + } + + val sum_to_divide_id = MuxCase((num_stats-1).U, + stats.zipWithIndex.map { case (s,i) => + (s.state === get_mean || s.state === get_variance) -> i.U } + ) + val sum_to_divide = stats(sum_to_divide_id).sum + val (divider_in, divider_out) = sum_to_divide.divider(stats.head.count).get + + { + // Divider input + val stat = stats(sum_to_divide_id) + + divider_in.valid := (stat.state === get_mean || stat.state === get_variance) && !lanes.io.busy + divider_in.bits := stat.count + } + + { + // Divider output + val waiting_for_divide_id = MuxCase((num_stats-1).U, + stats.zipWithIndex.map { case (s,i) => + (s.state === waiting_for_mean || s.state === waiting_for_variance) -> i.U } + ) + val stat = stats(waiting_for_divide_id) + + divider_out.ready := stat.state === waiting_for_mean || stat.state === waiting_for_variance + + when(stat.state === waiting_for_mean) { + stat.mean := divider_out.bits + }.elsewhen(stat.state === waiting_for_variance) { + stat.inv_stddev := divider_out.bits + } + } + + val variance_to_sqrt_id = MuxCase((num_stats-1).U, + stats.zipWithIndex.map { case (s,i) => + (s.state === get_stddev) -> i.U } + ) + val variance_to_sqrt = stats(variance_to_sqrt_id).inv_stddev + val (sqrt_in, sqrt_out) = variance_to_sqrt.sqrt.get + + { + // Sqrt input + val stat = stats(variance_to_sqrt_id) + + sqrt_in.valid := stat.state === get_stddev + } + + { + // Sqrt output + val waiting_for_sqrt_id = MuxCase((num_stats-1).U, + stats.zipWithIndex.map { case (s,i) => + (s.state === waiting_for_stddev) -> i.U } + ) + val stat = stats(waiting_for_sqrt_id) + + sqrt_out.ready := stat.state === 
waiting_for_stddev + + // TODO this fallback for stddev === 0 only works if acc_t is an SInt + assert(acc_t.isInstanceOf[SInt]) + + when (stat.state === waiting_for_stddev) { + stat.inv_stddev := Mux(sqrt_out.bits.asUInt() === acc_t.zero.asUInt(), + 1.S(acc_t.getWidth.W).asTypeOf(acc_t), + sqrt_out.bits + ) + } + } + + val stddev_to_inv_id = MuxCase((num_stats-1).U, + stats.zipWithIndex.map { case (s,i) => + (s.state === get_inv_stddev) -> i.U } + ) + val stddev_to_inv = stats(stddev_to_inv_id).inv_stddev + val (reciprocal_in, reciprocal_out) = stddev_to_inv.reciprocal(scale_t).get + + { + // Reciprocal input + val stat = stats(stddev_to_inv_id) + + reciprocal_in.valid := stat.state === get_inv_stddev + reciprocal_in.bits := DontCare + } + + { + // Reciprocal output + val waiting_for_reciprocal_id = MuxCase((num_stats-1).U, + stats.zipWithIndex.map { case (s,i) => + (s.state === waiting_for_inv_stddev) -> i.U } + ) + val stat = stats(waiting_for_reciprocal_id) + + reciprocal_out.ready := stat.state === waiting_for_inv_stddev + + when (stat.state === waiting_for_inv_stddev) { + stat.inv_stddev := reciprocal_out.bits.asTypeOf(stat.inv_stddev) + } + } + + val sum_exp_to_inv_id = MuxCase((num_stats-1).U, + stats.zipWithIndex.map { case (s,i) => + (s.state === get_inv_sum_exp) -> i.U } + ) + val sum_exp_to_inv = stats(sum_exp_to_inv_id).sum + val exp_divider_in = Wire(Decoupled(UInt(0.W))) + val exp_divider_out = Wire(Decoupled(scale_t.cloneType)) + + scale_t match { + case Float(expWidth, sigWidth) => + + exp_divider_in.bits := DontCare + + // We translate our integer to floating-point form so that we can use the hardfloat divider + def in_to_float(x: SInt) = { + val in_to_rec_fn = Module(new INToRecFN(intWidth = sum_exp_to_inv.getWidth, expWidth, sigWidth)) + in_to_rec_fn.io.signedIn := true.B + in_to_rec_fn.io.in := x.asUInt() + in_to_rec_fn.io.roundingMode := consts.round_near_even // consts.round_near_maxMag + in_to_rec_fn.io.detectTininess := 
consts.tininess_afterRounding + + in_to_rec_fn.io.out + } + + val self_rec = in_to_float(sum_exp_to_inv.asUInt().asSInt()) + val one_rec = in_to_float(127.S) // softmax maximum is 127 for signed int8 + + // Instantiate the hardloat divider + val divider = Module(new DivSqrtRecFN_small(expWidth, sigWidth, 0)) + + exp_divider_in.ready := divider.io.inReady + divider.io.inValid := exp_divider_in.valid + divider.io.sqrtOp := false.B + divider.io.a := one_rec + divider.io.b := self_rec + divider.io.roundingMode := consts.round_near_even + divider.io.detectTininess := consts.tininess_afterRounding + + exp_divider_out.valid := divider.io.outValid_div + exp_divider_out.bits := fNFromRecFN(expWidth, sigWidth, divider.io.out).asTypeOf(scale_t) + } + + + { + // Divider input + val stat = stats(sum_exp_to_inv_id) + + exp_divider_in.valid := (stat.state === get_inv_sum_exp) && !lanes.io.busy + exp_divider_in.bits := sum_exp_to_inv.asUInt() + } + + { + // Divider output + val waiting_for_divide_id = MuxCase((num_stats-1).U, + stats.zipWithIndex.map { case (s,i) => + (s.state === waiting_for_inv_sum_exp) -> i.U } + ) + val stat = stats(waiting_for_divide_id) + + exp_divider_out.ready := stat.state === waiting_for_inv_sum_exp + + when (stat.state === waiting_for_inv_sum_exp) { + stat.inv_sum_exp := exp_divider_out.bits.asTypeOf(stat.inv_sum_exp) + } + } + + // State transitions + for (((stat, next_state), id) <- (stats zip next_states).zipWithIndex) { + val state = stat.state + val cmd = stat.cmd + + val done = done_with_functional_units(id) + + when (state === idle) { + // We have a different "when" statement below to support the case where a new row is input into the normalizer + next_state := idle + done := DontCare + }.elsewhen(state === output) { + next_state := Mux(io.out.fire() && out_stats_id === id.U, idle, state) + done := io.out.fire() && out_stats_id === id.U + }.elsewhen(state === get_max) { + val is_last_lane_input = stat.vec_groups_left === 0.U || + 
(stat.vec_groups_left === 1.U && + max_lanes.io.ins.bits.stats_id === id.U && + max_lanes.io.ins.fire()) + + next_state := Mux( + is_last_lane_input, + MuxCase(state, Seq( + (cmd === NormCmd.MAX) -> idle, + (cmd === NormCmd.SUM_EXP || cmd === NormCmd.INV_SUM_EXP) -> get_sum + )), + state + ) + + done := is_last_lane_input && cmd === NormCmd.MAX + }.elsewhen(state === get_sum) { + val is_last_lane_input = stat.vec_groups_left === 0.U || + (stat.vec_groups_left === 1.U && + lanes.io.ins.bits.stats_id === id.U && + lanes.io.ins.fire()) + + next_state := Mux( + is_last_lane_input, + MuxCase(state, Seq( + (cmd === NormCmd.SUM || cmd === NormCmd.VARIANCE || cmd === NormCmd.SUM_EXP) -> idle, + (cmd === NormCmd.MEAN) -> get_mean, + (cmd === NormCmd.INV_STDDEV) -> get_variance, + (cmd === NormCmd.INV_SUM_EXP) -> get_inv_sum_exp, + )), + state + ) +// next_state := Mux(cmd === NormCmd.SUM || cmd === NormCmd.VARIANCE, +// Mux(is_last_lane_input, idle, state), +// Mux(is_last_lane_input, +// Mux(cmd === NormCmd.MEAN, get_mean, get_variance), +// state) +// ) + + done := is_last_lane_input && cmd =/= NormCmd.MEAN && cmd =/= NormCmd.INV_STDDEV && cmd =/= NormCmd.INV_SUM_EXP + }.elsewhen(state === get_mean || state === get_variance) { + next_state := Mux(divider_in.fire() && sum_to_divide_id === id.U, state.next, state) + done := false.B + }.elsewhen(state === waiting_for_mean) { + next_state := Mux(divider_out.fire(), idle, state) + done := divider_out.fire() + }.elsewhen(state === waiting_for_variance) { + next_state := Mux(divider_out.fire(), get_stddev, state) + done := false.B + }.elsewhen(state === get_stddev) { + next_state := Mux(sqrt_in.fire() && variance_to_sqrt_id === id.U, state.next, state) + done := false.B + }.elsewhen(state === waiting_for_stddev) { + next_state := Mux(sqrt_out.fire(), state.next, state) + done := false.B + }.elsewhen(state === get_inv_stddev) { + next_state := Mux(reciprocal_in.fire() && stddev_to_inv_id === id.U, state.next, state) + done := 
false.B + }.elsewhen(state === waiting_for_inv_stddev) { + next_state := Mux(reciprocal_out.fire(), idle, state) + done := reciprocal_out.fire() + }.elsewhen(state === get_inv_sum_exp) { + next_state := Mux(exp_divider_in.fire() && sum_exp_to_inv_id === id.U, state.next, state) + done := false.B + }.elsewhen(state === waiting_for_inv_sum_exp) { + next_state := Mux(exp_divider_out.fire(), idle, state) + done := exp_divider_out.fire() + }.otherwise { + assert(false.B, "invalid state in Normalizer") + next_state := DontCare + done := DontCare + } + + when (io.in.fire() && in_stats_id === id.U) { + next_state := Mux(io.in.bits.cmd === NormCmd.RESET, output, + Mux(io.in.bits.cmd === NormCmd.MAX, get_max, get_sum)) + when (io.in.bits.cmd === NormCmd.SUM_EXP) { + stat.max := stat.running_max + } + } + } + + // Update stats variables + for (((stat, next_state), id) <- (stats zip next_states).zipWithIndex) { + val state = stat.state + + val reset_running_state = + state === output || + (state === get_mean && next_state =/= get_mean) || + (state === get_variance && next_state =/= get_variance) + + val is_input = io.in.fire() && in_stats_id === id.U + + when (is_input) { + stat.req := io.in.bits + stat.count := stat.count + io.in.bits.len + stat.elems_left := io.in.bits.len + } + + when(reset_running_state) { + stat.sum := acc_t.zero + stat.count := Mux(is_input, io.in.bits.len, 0.U) + } + + when (state =/= get_inv_sum_exp && next_state === get_inv_sum_exp) { + stat.running_max := acc_t.minimum + } + } + + dontTouch(stats) + + // Assertions + assert(PopCount(stats.map(s => s.state === waiting_for_mean || s.state === waiting_for_variance)) <= 1.U, "we don't support pipelining the divider/sqrt-unit/inv-unit right now") + assert(PopCount(stats.map(_.state === waiting_for_stddev)) <= 1.U, "we don't support pipelining the divider/sqrt-unit/inv-unit right now") + assert(PopCount(stats.map(_.state === waiting_for_inv_stddev)) <= 1.U, "we don't support pipelining the 
divider/sqrt-unit/inv-unit right now") + assert(PopCount(stats.map(_.state === output)) <= 1.U, "multiple outputs at same time") + assert(acc_t.getWidth == scale_t.getWidth, "we use the same variable to hold both the variance and the inv-stddev, so we need them to see the width") + + // Resets + when (reset.asBool()) { + stats.foreach(_.state := idle) + stats.foreach(_.sum := acc_t.zero) + stats.foreach(_.max := acc_t.minimum) + stats.foreach(_.running_max := acc_t.minimum) + stats.foreach(_.count := 0.U) + stats.foreach(_.inv_sum_exp := acc_t.zero) + } +} + +object Normalizer { + def apply[T <: Data, U <: Data](is_passthru: Boolean, max_len: Int, num_reduce_lanes: Int, num_stats: Int, + latency: Int, fullDataType: Vec[Vec[T]], scale_t: U)(implicit ev: Arithmetic[T]): + (DecoupledIO[NormalizedInput[T,U]], DecoupledIO[NormalizedOutput[T,U]]) = { + if (is_passthru) { + passthru(max_len = max_len, num_stats = num_stats, fullDataType = fullDataType, scale_t = scale_t) + } else { + gen(max_len = max_len, num_reduce_lanes = num_reduce_lanes, num_stats = num_stats, latency = latency, + fullDataType = fullDataType, scale_t = scale_t) + } + } + + def gen[T <: Data, U <: Data](max_len: Int, num_reduce_lanes: Int, num_stats: Int, latency: Int, + fullDataType: Vec[Vec[T]], scale_t: U)(implicit ev: Arithmetic[T]): (DecoupledIO[NormalizedInput[T,U]], DecoupledIO[NormalizedOutput[T,U]]) = { + val norm_unit_module = Module(new Normalizer(max_len, num_reduce_lanes, num_stats, latency, fullDataType, scale_t)) + (norm_unit_module.io.in, norm_unit_module.io.out) + } + + def passthru[T <: Data, U <: Data](max_len: Int, num_stats: Int, fullDataType: Vec[Vec[T]], scale_t: U) + (implicit ev: Arithmetic[T]): (DecoupledIO[NormalizedInput[T,U]], DecoupledIO[NormalizedOutput[T,U]]) = { + + val norm_unit_passthru_q = Module(new Queue(new NormalizedInput[T,U](max_len, num_stats, fullDataType, scale_t), 2)) + val norm_unit_passthru_out = Wire(Decoupled(new NormalizedOutput(fullDataType, 
scale_t))) + + norm_unit_passthru_out.valid := norm_unit_passthru_q.io.deq.valid + norm_unit_passthru_out.bits.acc_read_resp := norm_unit_passthru_q.io.deq.bits.acc_read_resp + norm_unit_passthru_out.bits.mean := DontCare + norm_unit_passthru_out.bits.max := DontCare + norm_unit_passthru_out.bits.inv_stddev := DontCare + norm_unit_passthru_out.bits.inv_sum_exp := DontCare + + norm_unit_passthru_q.io.deq.ready := norm_unit_passthru_out.ready + + (norm_unit_passthru_q.io.enq, norm_unit_passthru_out) + } +} diff --git a/src/main/scala/gemmini/ReservationStation.scala b/src/main/scala/gemmini/ReservationStation.scala index 8bb03415..68d0e6e7 100644 --- a/src/main/scala/gemmini/ReservationStation.scala +++ b/src/main/scala/gemmini/ReservationStation.scala @@ -23,7 +23,8 @@ class ReservationStationIssue[T <: Data](cmd_t: T, id_width: Int) extends Bundle } // TODO we don't need to store the full command in here. We should be able to release the command directly into the relevant controller and only store the associated metadata in the ROB. 
This would reduce the size considerably -class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V], cmd_t: GemminiCmd) extends Module { +class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V], + cmd_t: GemminiCmd) extends Module { import config._ val block_rows = tileRows * meshRows @@ -179,7 +180,6 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G val alloc_fire = io.alloc.fire() - dontTouch(new_entry) io.alloc.ready := false.B when (io.alloc.valid) { val spAddrBits = 32 @@ -251,7 +251,7 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G val mvout_cols = cmd.rs2(32 + mvout_cols_bits - 1, 32) val mvout_rows = cmd.rs2(48 + mvout_rows_bits - 1, 48) - val mvout_mats = mvout_cols / block_cols.U + (mvout_cols % block_cols.U =/= 0.U) + val mvout_mats = mvout_cols / block_cols.U(mvout_cols_bits.W) + (mvout_cols % block_cols.U =/= 0.U) val total_mvout_rows = ((mvout_mats - 1.U) * block_stride) + mvout_rows op2.bits.end := op2.bits.start + total_mvout_rows @@ -273,7 +273,7 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G val mvin_cols = cmd.rs2(32 + mvin_cols_bits - 1, 32) val mvin_rows = cmd.rs2(48 + mvin_rows_bits - 1, 48) - val mvin_mats = mvin_cols / block_cols.U + (mvin_cols % block_cols.U =/= 0.U) + val mvin_mats = mvin_cols / block_cols.U(mvin_cols_bits.W) + (mvin_cols % block_cols.U =/= 0.U) val total_mvin_rows = ((mvin_mats - 1.U) * block_stride) + mvin_rows // TODO We have to know how the LoopConv's internals work here. 
Our abstractions are leaking @@ -293,9 +293,9 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G } val is_load = funct === LOAD_CMD || funct === LOAD2_CMD || funct === LOAD3_CMD || (funct === CONFIG_CMD && config_cmd_type === CONFIG_LOAD) - val is_store = funct === STORE_CMD || (funct === CONFIG_CMD && config_cmd_type === CONFIG_STORE) - val is_ex = funct === PRELOAD_CMD || funct_is_compute || (funct === CONFIG_CMD && (config_cmd_type === CONFIG_EX || config_cmd_type === CONFIG_IM2COL)) - val is_im2col = funct === CONFIG_CMD && config_cmd_type === CONFIG_IM2COL // im2col commands are a subset of ex commands, so they still go in the ex queue + val is_ex = funct === PRELOAD_CMD || funct_is_compute || (funct === CONFIG_CMD && config_cmd_type === CONFIG_EX) + val is_store = funct === STORE_CMD || (funct === CONFIG_CMD && (config_cmd_type === CONFIG_STORE || config_cmd_type === CONFIG_NORM)) + val is_norm = funct === CONFIG_CMD && config_cmd_type === CONFIG_NORM // normalization commands are a subset of store commands, so they still go in the store queue new_entry.q := Mux1H(Seq( is_load -> ldq, @@ -364,7 +364,7 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G } when (io.alloc.fire) { - when (new_entry.is_config && new_entry.q === exq && !is_im2col) { + when (new_entry.is_config && new_entry.q === exq) { a_stride := new_entry.cmd.cmd.rs1(31, 16) // TODO magic numbers // TODO this needs to be kept in sync with ExecuteController.scala c_stride := new_entry.cmd.cmd.rs2(63, 48) // TODO magic numbers // TODO this needs to be kept in sync with ExecuteController.scala val set_only_strides = new_entry.cmd.cmd.rs1(7) // TODO magic numbers @@ -377,7 +377,7 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G val repeat_pixels = maxOf(new_entry.cmd.cmd.rs1(8 + pixel_repeats_bits - 1, 8), 1.U) // TODO we use a default value of pixel repeats here, for backwards compatibility. 
However, we should deprecate and remove this default value eventually ld_block_strides(id) := block_stride ld_pixel_repeats(id) := repeat_pixels - 1.U - }.elsewhen(new_entry.is_config && new_entry.q === stq) { + }.elsewhen(new_entry.is_config && new_entry.q === stq && !is_norm) { val pool_stride = new_entry.cmd.cmd.rs1(5, 4) // TODO magic numbers pooling_is_enabled := pool_stride =/= 0.U }.elsewhen(funct === PRELOAD_CMD) { diff --git a/src/main/scala/gemmini/Scratchpad.scala b/src/main/scala/gemmini/Scratchpad.scala index 008dc990..70c9140f 100644 --- a/src/main/scala/gemmini/Scratchpad.scala +++ b/src/main/scala/gemmini/Scratchpad.scala @@ -1,3 +1,4 @@ + package gemmini import chisel3._ @@ -6,7 +7,7 @@ import freechips.rocketchip.config.Parameters import freechips.rocketchip.diplomacy.{LazyModule, LazyModuleImp} import freechips.rocketchip.rocket._ import freechips.rocketchip.tile._ -import freechips.rocketchip.tilelink.{TLIdentityNode, TLXbar, TLBuffer} +import freechips.rocketchip.tilelink._ import Util._ @@ -26,13 +27,18 @@ class ScratchpadMemReadRequest[U <: Data](local_addr_t: LocalAddr, scale_t_bits: } -class ScratchpadMemWriteRequest(local_addr_t: LocalAddr, scale_t_bits: Int) +class ScratchpadMemWriteRequest(local_addr_t: LocalAddr, acc_t_bits: Int, scale_t_bits: Int) (implicit p: Parameters) extends CoreBundle { val vaddr = UInt(coreMaxAddrBits.W) val laddr = local_addr_t.cloneType - val acc_act = UInt(2.W) // TODO don't use a magic number for the width here + val acc_act = UInt(Activation.bitwidth.W) // TODO don't use a magic number for the width here val acc_scale = UInt(scale_t_bits.W) + val acc_igelu_qb = UInt(acc_t_bits.W) + val acc_igelu_qc = UInt(acc_t_bits.W) + val acc_iexp_qln2 = UInt(acc_t_bits.W) + val acc_iexp_qln2_inv = UInt(acc_t_bits.W) + val acc_norm_stats_id = UInt(8.W) // TODO magic number val len = UInt(16.W) // TODO don't use a magic number for the width here val block = UInt(8.W) // TODO don't use a magic number for the width here @@ 
-58,14 +64,12 @@ class ScratchpadMemReadResponse extends Bundle { class ScratchpadReadMemIO[U <: Data](local_addr_t: LocalAddr, scale_t_bits: Int)(implicit p: Parameters) extends CoreBundle { val req = Decoupled(new ScratchpadMemReadRequest(local_addr_t, scale_t_bits)) val resp = Flipped(Valid(new ScratchpadMemReadResponse)) - } -class ScratchpadWriteMemIO(local_addr_t: LocalAddr, scale_t_bits: Int) +class ScratchpadWriteMemIO(local_addr_t: LocalAddr, acc_t_bits: Int, scale_t_bits: Int) (implicit p: Parameters) extends CoreBundle { - val req = Decoupled(new ScratchpadMemWriteRequest(local_addr_t, scale_t_bits)) + val req = Decoupled(new ScratchpadMemWriteRequest(local_addr_t, acc_t_bits, scale_t_bits)) val resp = Flipped(Valid(new ScratchpadMemWriteResponse)) - } class ScratchpadReadReq(val n: Int) extends Bundle { @@ -195,7 +199,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, xbar_node := TLBuffer() := reader.node // TODO xbar_node := TLBuffer() := writer.node - id_node := TLBuffer() := xbar_node + id_node := TLWidthWidget(config.dma_buswidth/8) := TLBuffer() := xbar_node lazy val module = new LazyModuleImp(this) with HasCoreParameters { @@ -203,7 +207,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // DMA ports val dma = new Bundle { val read = Flipped(new ScratchpadReadMemIO(local_addr_t, mvin_scale_t_bits)) - val write = Flipped(new ScratchpadWriteMemIO(local_addr_t, acc_scale_t_bits)) + val write = Flipped(new ScratchpadWriteMemIO(local_addr_t, accType.getWidth, acc_scale_t_bits)) } // SRAM ports @@ -215,7 +219,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // Accumulator ports val acc = new Bundle { val read_req = Flipped(Vec(acc_banks, Decoupled(new AccumulatorReadReq( - acc_bank_entries, log2Up(accType.getWidth), acc_scale_t.asInstanceOf[V] + acc_bank_entries, accType, acc_scale_t.asInstanceOf[V] )))) val read_resp = Vec(acc_banks, Decoupled(new 
AccumulatorScaleResp( Vec(meshColumns, Vec(tileColumns, inputType)), @@ -242,25 +246,37 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, }) val write_dispatch_q = Queue(io.dma.write.req) - write_dispatch_q.ready := false.B - // Write scale queue is necessary to maintain in-order requests to accumulator scale unit + // Write norm/scale queues are necessary to maintain in-order requests to accumulator norm/scale units // Writes from main SPAD just flow directly between scale_q and issue_q, while writes // From acc are ordered - val write_scale_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t, acc_scale_t_bits), spad_read_delay)) - val write_issue_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t, acc_scale_t_bits), spad_read_delay+1, pipe=true)) + val write_norm_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t, accType.getWidth, acc_scale_t_bits), spad_read_delay+2)) + val write_scale_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t, accType.getWidth, acc_scale_t_bits), spad_read_delay+2)) + val write_issue_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t, accType.getWidth, acc_scale_t_bits), spad_read_delay+1, pipe=true)) val read_issue_q = Module(new Queue(new ScratchpadMemReadRequest(local_addr_t, mvin_scale_t_bits), spad_read_delay+1, pipe=true)) // TODO can't this just be a normal queue? 
+ write_dispatch_q.ready := false.B + + write_norm_q.io.enq.valid := false.B + write_norm_q.io.enq.bits := write_dispatch_q.bits + write_norm_q.io.deq.ready := false.B + write_scale_q.io.enq.valid := false.B - write_scale_q.io.enq.bits := write_dispatch_q.bits + write_scale_q.io.enq.bits := write_norm_q.io.deq.bits write_scale_q.io.deq.ready := false.B write_issue_q.io.enq.valid := false.B write_issue_q.io.enq.bits := write_scale_q.io.deq.bits - // Garbage can immediately fire between dispatch_q and scale_q + // Garbage can immediately fire from dispatch_q -> norm_q when (write_dispatch_q.bits.laddr.is_garbage()) { - write_scale_q.io.enq <> write_dispatch_q + write_norm_q.io.enq <> write_dispatch_q } + + // Non-acc or garbage can immediately fire between norm_q and scale_q + when (write_norm_q.io.deq.bits.laddr.is_garbage() || !write_norm_q.io.deq.bits.laddr.is_acc_addr) { + write_scale_q.io.enq <> write_norm_q.io.deq + } + // Non-acc or garbage can immediately fire between scale_q and issue_q when (write_scale_q.io.deq.bits.laddr.is_garbage() || !write_scale_q.io.deq.bits.laddr.is_acc_addr) { write_issue_q.io.enq <> write_scale_q.io.deq @@ -425,7 +441,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, writer.module.io.flush := io.flush reader.module.io.flush := io.flush - io.busy := writer.module.io.busy || reader.module.io.busy || write_issue_q.io.deq.valid || write_scale_q.io.deq.valid || write_dispatch_q.valid + io.busy := writer.module.io.busy || reader.module.io.busy || write_issue_q.io.deq.valid || write_norm_q.io.deq.valid || write_scale_q.io.deq.valid || write_dispatch_q.valid val spad_mems = { val banks = Seq.fill(sp_banks) { Module(new ScratchpadBank( @@ -444,7 +460,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, val exread = ex_read_req.valid // TODO we tie the write dispatch queue's, and write issue queue's, ready and valid signals together here - val dmawrite = 
write_dispatch_q.valid && write_scale_q.io.enq.ready && + val dmawrite = write_dispatch_q.valid && write_norm_q.io.enq.ready && !write_dispatch_q.bits.laddr.is_garbage() && !(bio.write.en && config.sp_singleported.B) && !write_dispatch_q.bits.laddr.is_acc_addr && write_dispatch_q.bits.laddr.sp_bank() === i.U @@ -462,7 +478,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, when (bio.read.req.fire) { write_dispatch_q.ready := true.B - write_scale_q.io.enq.valid := true.B + write_norm_q.io.enq.valid := true.B io.dma.write.resp.valid := true.B } @@ -543,34 +559,73 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, val acc_row_t = Vec(meshColumns, Vec(tileColumns, accType)) val spad_row_t = Vec(meshColumns, Vec(tileColumns, inputType)) +// val acc_norm_unit = Module(new Normalizer( +// max_len = block_cols, +// num_reduce_lanes = -1, +// num_stats = 4, +// latency = 4, +// fullDataType = acc_row_t, +// scale_t = acc_scale_t, +// )) + + val (acc_norm_unit_in, acc_norm_unit_out) = Normalizer( + is_passthru = !config.has_normalizations, + max_len = block_cols, + num_reduce_lanes = -1, + num_stats = 4, + latency = 4, + fullDataType = acc_row_t, + scale_t = acc_scale_t, + ) + + acc_norm_unit_in.valid := false.B + acc_norm_unit_in.bits.len := write_norm_q.io.deq.bits.len + acc_norm_unit_in.bits.stats_id := write_norm_q.io.deq.bits.acc_norm_stats_id + acc_norm_unit_in.bits.cmd := write_norm_q.io.deq.bits.laddr.norm_cmd + acc_norm_unit_in.bits.acc_read_resp := DontCare + val acc_scale_unit = Module(new AccumulatorScale( acc_row_t, spad_row_t, acc_scale_t.asInstanceOf[V], - log2Up(accType.getWidth), acc_read_small_width, acc_read_full_width, acc_scale_func, acc_scale_num_units, acc_scale_latency, has_nonlinear_activations, + has_normalizations, )) - acc_scale_unit.io.in.valid := false.B - acc_scale_unit.io.in.bits := DontCare - val dma_resp_ready = ( - writer.module.io.req.ready && - 
write_issue_q.io.deq.bits.laddr.is_acc_addr && - !write_issue_q.io.deq.bits.laddr.is_garbage() - ) + val acc_waiting_to_be_scaled = write_scale_q.io.deq.valid && + !write_scale_q.io.deq.bits.laddr.is_garbage() && + write_scale_q.io.deq.bits.laddr.is_acc_addr && + write_issue_q.io.enq.ready + + acc_norm_unit_out.ready := acc_scale_unit.io.in.ready && acc_waiting_to_be_scaled + acc_scale_unit.io.in.valid := acc_norm_unit_out.valid && acc_waiting_to_be_scaled + acc_scale_unit.io.in.bits := acc_norm_unit_out.bits + + when (acc_scale_unit.io.in.fire()) { + write_issue_q.io.enq <> write_scale_q.io.deq + } + acc_scale_unit.io.out.ready := false.B + + val dma_resp_ready = + writer.module.io.req.ready && + write_issue_q.io.deq.bits.laddr.is_acc_addr && + !write_issue_q.io.deq.bits.laddr.is_garbage() + when (acc_scale_unit.io.out.bits.fromDMA && dma_resp_ready) { + // Send the acc-scale result into the DMA acc_scale_unit.io.out.ready := true.B writeData.valid := acc_scale_unit.io.out.valid writeData.bits := acc_scale_unit.io.out.bits.data.asUInt fullAccWriteData := acc_scale_unit.io.out.bits.full_data.asUInt } for (i <- 0 until acc_banks) { + // Send the acc-scale result to the ExController io.acc.read_resp(i).valid := false.B io.acc.read_resp(i).bits := acc_scale_unit.io.out.bits when (!acc_scale_unit.io.out.bits.fromDMA && acc_scale_unit.io.out.bits.acc_bank_id === i.U) { @@ -608,18 +663,21 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, val exread = ex_read_req.valid // TODO we tie the write dispatch queue's, and write issue queue's, ready and valid signals together here - val dmawrite = write_dispatch_q.valid && write_scale_q.io.enq.ready && + val dmawrite = write_dispatch_q.valid && write_norm_q.io.enq.ready && !write_dispatch_q.bits.laddr.is_garbage() && write_dispatch_q.bits.laddr.is_acc_addr && write_dispatch_q.bits.laddr.acc_bank() === i.U bio.read.req.valid := exread || dmawrite - bio.read.req.bits.relu6_shift := 
ex_read_req.bits.relu6_shift ex_read_req.ready := bio.read.req.ready // The ExecuteController gets priority when reading from accumulator banks when (exread) { bio.read.req.bits.addr := ex_read_req.bits.addr bio.read.req.bits.act := ex_read_req.bits.act + bio.read.req.bits.igelu_qb := ex_read_req.bits.igelu_qb + bio.read.req.bits.igelu_qc := ex_read_req.bits.igelu_qc + bio.read.req.bits.iexp_qln2 := ex_read_req.bits.iexp_qln2 + bio.read.req.bits.iexp_qln2_inv := ex_read_req.bits.iexp_qln2_inv bio.read.req.bits.scale := ex_read_req.bits.scale bio.read.req.bits.full := false.B bio.read.req.bits.fromDMA := false.B @@ -627,12 +685,16 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, bio.read.req.bits.addr := write_dispatch_q.bits.laddr.acc_row() bio.read.req.bits.full := write_dispatch_q.bits.laddr.read_full_acc_row bio.read.req.bits.act := write_dispatch_q.bits.acc_act + bio.read.req.bits.igelu_qb := write_dispatch_q.bits.acc_igelu_qb.asTypeOf(bio.read.req.bits.igelu_qb) + bio.read.req.bits.igelu_qc := write_dispatch_q.bits.acc_igelu_qc.asTypeOf(bio.read.req.bits.igelu_qc) + bio.read.req.bits.iexp_qln2 := write_dispatch_q.bits.acc_iexp_qln2.asTypeOf(bio.read.req.bits.iexp_qln2) + bio.read.req.bits.iexp_qln2_inv := write_dispatch_q.bits.acc_iexp_qln2_inv.asTypeOf(bio.read.req.bits.iexp_qln2_inv) bio.read.req.bits.scale := write_dispatch_q.bits.acc_scale.asTypeOf(bio.read.req.bits.scale) bio.read.req.bits.fromDMA := true.B when (bio.read.req.fire) { write_dispatch_q.ready := true.B - write_scale_q.io.enq.valid := true.B + write_norm_q.io.enq.valid := true.B io.dma.write.resp.valid := true.B } @@ -641,22 +703,24 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, } bio.read.resp.ready := false.B - when (write_scale_q.io.deq.valid && - acc_scale_unit.io.in.ready && - bio.read.resp.valid && - write_issue_q.io.enq.ready && - write_scale_q.io.deq.bits.laddr.is_acc_addr && - 
!write_scale_q.io.deq.bits.laddr.is_garbage() && - write_scale_q.io.deq.bits.laddr.acc_bank() === i.U) { - write_scale_q.io.deq.ready := true.B - acc_scale_unit.io.in.valid := true.B + when (write_norm_q.io.deq.valid && + acc_norm_unit_in.ready && + bio.read.resp.valid && + write_scale_q.io.enq.ready && + write_norm_q.io.deq.bits.laddr.is_acc_addr && + !write_norm_q.io.deq.bits.laddr.is_garbage() && + write_norm_q.io.deq.bits.laddr.acc_bank() === i.U) + { + write_norm_q.io.deq.ready := true.B + acc_norm_unit_in.valid := true.B bio.read.resp.ready := true.B - write_issue_q.io.enq.valid := true.B - acc_scale_unit.io.in.bits := bio.read.resp.bits - acc_scale_unit.io.in.bits.acc_bank_id := i.U - } + // Some normalizer commands don't write to main memory, so they don't need to be passed on to the scaling units + write_scale_q.io.enq.valid := NormCmd.writes_to_main_memory(write_norm_q.io.deq.bits.laddr.norm_cmd) + acc_norm_unit_in.bits.acc_read_resp := bio.read.resp.bits + acc_norm_unit_in.bits.acc_read_resp.acc_bank_id := i.U + } } // Writing to the accumulator banks @@ -682,7 +746,6 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // We need to make sure that we don't try to return a dma read resp from both mvin_scale and mvin_scale_acc // at the same time. 
mvin_scale always gets priority in this cases - // val spad_last = mvin_scale_out.valid && mvin_scale_out.bits.last && !mvin_scale_out.bits.tag.is_acc val spad_last = mvin_scale_pixel_repeater.io.resp.valid && mvin_scale_pixel_repeater.io.resp.bits.last && !mvin_scale_pixel_repeater.io.resp.bits.tag.is_acc val dmaread = (from_mvin_scale || from_mvin_scale_acc) && diff --git a/src/main/scala/gemmini/StoreController.scala b/src/main/scala/gemmini/StoreController.scala index 692a8e04..c9e4fdbb 100644 --- a/src/main/scala/gemmini/StoreController.scala +++ b/src/main/scala/gemmini/StoreController.scala @@ -11,14 +11,14 @@ import midas.targetutils.PerfCounter // TODO this is almost a complete copy of LoadController. We should combine them into one class // TODO deal with errors when reading scratchpad responses -class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V], coreMaxAddrBits: Int, local_addr_t: LocalAddr) - (implicit p: Parameters) extends Module { +class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V], + coreMaxAddrBits: Int, local_addr_t: LocalAddr)(implicit p: Parameters) extends Module { import config._ val io = IO(new Bundle { val cmd = Flipped(Decoupled(new GemminiCmd(reservation_station_entries))) - val dma = new ScratchpadWriteMemIO(local_addr_t, acc_scale_t_bits) + val dma = new ScratchpadWriteMemIO(local_addr_t, accType.getWidth, acc_scale_t_bits) val completed = Decoupled(UInt(log2Up(reservation_station_entries).W)) @@ -42,7 +42,12 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm val block_cols = meshColumns * tileColumns val max_blocks = (dma_maxbytes / (block_cols * inputType.getWidth / 8)) max 1 - val activation = Reg(UInt(GemminiISA.CONFIG_MVOUT_RS1_ACTIVATION_WIDTH.W)) + val activation = Reg(UInt(Activation.bitwidth.W)) // TODO magic number + val igelu_qb = Reg(accType) + val igelu_qc = Reg(accType) + val iexp_qln2 
= Reg(accType) + val iexp_qln2_inv = Reg(accType) + val norm_stats_id = Reg(UInt(8.W)) // TODO magic number val acc_scale = Reg(acc_scale_t) //val row_counter = RegInit(0.U(log2Ceil(block_rows).W)) @@ -83,10 +88,11 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm val localaddr = mvout_rs2.local_addr val cols = mvout_rs2.num_cols val rows = mvout_rs2.num_rows - val blocks = (cols / block_cols.U) + (cols % block_cols.U =/= 0.U) + val blocks = (cols / block_cols.U(cols.getWidth.W)) + (cols % block_cols.U =/= 0.U) val config_mvout_rs1 = cmd.bits.cmd.rs1.asTypeOf(new ConfigMvoutRs1) val config_mvout_rs2 = cmd.bits.cmd.rs2.asTypeOf(new ConfigMvoutRs2(acc_scale_t_bits, 32)) + val config_cmd_type = config_mvout_rs1.cmd_type val config_stride = config_mvout_rs2.stride val config_activation = config_mvout_rs1.activation val config_acc_scale = config_mvout_rs2.acc_scale @@ -100,10 +106,22 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm val config_upad = config_mvout_rs1.upad val config_lpad = config_mvout_rs1.lpad + val config_norm_rs1 = cmd.bits.cmd.rs1.asTypeOf(new ConfigNormRs1(accType.getWidth)) + val config_norm_rs2 = cmd.bits.cmd.rs2.asTypeOf(new ConfigNormRs2(accType.getWidth)) + val config_stats_id = config_norm_rs1.norm_stats_id + val config_activation_msb = config_norm_rs1.act_msb + val config_set_stats_id_only = config_norm_rs1.set_stats_id_only + val config_iexp_q_const_type = config_norm_rs1.q_const_type + val config_iexp_q_const = config_norm_rs1.q_const + val config_igelu_qb = config_norm_rs2.qb + val config_igelu_qc = config_norm_rs2.qc + + assert(config_norm_rs1.cmd_type === config_mvout_rs1.cmd_type) + val mstatus = cmd.bits.cmd.status val current_vaddr = vaddr + row_counter * stride - val current_localaddr = localaddr + (block_counter * block_stride + row_counter) + val current_localaddr = WireInit(localaddr + (block_counter * block_stride + row_counter)) val pool_row_addr = localaddr + 
(orow * pool_ocols +& ocol) when (orow_is_negative || ocol_is_negative || orow >= pool_orows || ocol >= pool_ocols) { @@ -112,8 +130,9 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm val pool_vaddr = vaddr + (porow_counter * pool_out_dim + pocol_counter) * stride // TODO get rid of these multiplications - val DoConfig = cmd.bits.cmd.inst.funct === CONFIG_CMD - val DoStore = !DoConfig // TODO change this if more commands are added + val DoConfig = cmd.bits.cmd.inst.funct === CONFIG_CMD && config_cmd_type === CONFIG_STORE + val DoConfigNorm = config.has_normalizations.B && cmd.bits.cmd.inst.funct === CONFIG_CMD && config_cmd_type === CONFIG_NORM + val DoStore = !DoConfig && !DoConfigNorm cmd.ready := false.B @@ -140,8 +159,15 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm io.dma.req.bits.vaddr := Mux(pooling_is_enabled || mvout_1d_enabled, pool_vaddr, current_vaddr) io.dma.req.bits.laddr := Mux(pooling_is_enabled, pool_row_addr, current_localaddr) //Todo: laddr for 1D? 
+ io.dma.req.bits.laddr.norm_cmd := Mux(block_counter === blocks - 1.U, current_localaddr.norm_cmd, + NormCmd.non_reset_version(current_localaddr.norm_cmd)) io.dma.req.bits.acc_act := activation + io.dma.req.bits.acc_igelu_qb := igelu_qb.asTypeOf(io.dma.req.bits.acc_igelu_qb) + io.dma.req.bits.acc_igelu_qc := igelu_qc.asTypeOf(io.dma.req.bits.acc_igelu_qc) + io.dma.req.bits.acc_iexp_qln2 := iexp_qln2.asTypeOf(io.dma.req.bits.acc_iexp_qln2) + io.dma.req.bits.acc_iexp_qln2_inv := iexp_qln2_inv.asTypeOf(io.dma.req.bits.acc_iexp_qln2_inv) + io.dma.req.bits.acc_norm_stats_id := norm_stats_id io.dma.req.bits.acc_scale := acc_scale.asTypeOf(io.dma.req.bits.acc_scale) io.dma.req.bits.len := Mux(block_counter === blocks - 1.U, ((cols - 1.U) % block_cols.U) + 1.U, block_cols.U) @@ -221,10 +247,24 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm } cmd.ready := true.B } - .elsewhen(DoStore && cmd_tracker.io.alloc.fire()) { - val next_state = Mux(pooling_is_enabled, pooling, sending_rows) - control_state := Mux(io.dma.req.fire, next_state, waiting_for_dma_req_ready) + .elsewhen(config.has_normalizations.B && DoConfigNorm) { + when (!config_set_stats_id_only.asBool()) { + igelu_qb := config_igelu_qb.asTypeOf(igelu_qb) + igelu_qc := config_igelu_qc.asTypeOf(igelu_qc) + when(config_iexp_q_const_type === 0.U) { + iexp_qln2 := config_iexp_q_const.asTypeOf(iexp_qln2) + }.elsewhen(config_iexp_q_const_type === 1.U) { + iexp_qln2_inv := config_iexp_q_const.asTypeOf(iexp_qln2_inv) + } + activation := Cat(config_activation_msb, activation(1, 0)) // TODO: magic number } + norm_stats_id := config_stats_id + cmd.ready := true.B + } + .elsewhen(DoStore && cmd_tracker.io.alloc.fire()) { + val next_state = Mux(pooling_is_enabled, pooling, sending_rows) + control_state := Mux(io.dma.req.fire, next_state, waiting_for_dma_req_ready) + } } } @@ -260,6 +300,17 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm } } + // Optimizations 
when features are disabled + if (!config.has_normalizations) { + current_localaddr.norm_cmd := NormCmd.RESET + + igelu_qb := DontCare + igelu_qc := DontCare + iexp_qln2 := DontCare + iexp_qln2_inv := DontCare + norm_stats_id := 0.U + } + // Performance counter CounterEventIO.init(io.counter) io.counter.connectEventSignal(CounterEvent.STORE_ACTIVE_CYCLE, control_state === sending_rows || control_state === pooling) diff --git a/src/main/scala/gemmini/ZeroWriter.scala b/src/main/scala/gemmini/ZeroWriter.scala index a5c10abe..a1834a41 100644 --- a/src/main/scala/gemmini/ZeroWriter.scala +++ b/src/main/scala/gemmini/ZeroWriter.scala @@ -1,3 +1,4 @@ + package gemmini import chisel3._ @@ -40,7 +41,14 @@ class ZeroWriter[T <: Data, U <: Data, V <: Data, Tag <: Data](config: GemminiAr io.req.ready := !req.valid io.resp.valid := req.valid - io.resp.bits.laddr := req.bits.laddr + req.bits.block_stride * (col_counter / block_cols.U) + io.resp.bits.laddr := req.bits.laddr + req.bits.block_stride * { + // This code block was originally just "col_counter / block_cols.U". We + // changed it to satisfy Verilator's linter + if (col_counter.getWidth >= log2Ceil(block_cols+1)) + (col_counter / block_cols.U(col_counter.getWidth.W)) + else + 0.U + } io.resp.bits.mask.zipWithIndex.foreach { case (m, i) => m := col_counter + i.U < req.bits.cols } io.resp.bits.last := col_counter +& block_cols.U >= req.bits.cols io.resp.bits.tag := req.bits.tag