Skip to content

Instantly share code, notes, and snippets.

@ayah527

ayah527/argmax.s Secret

Last active May 8, 2021 20:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ayah527/09fb5ac6df595f153018eb4d0323a9ba to your computer and use it in GitHub Desktop.
Save ayah527/09fb5ac6df595f153018eb4d0323a9ba to your computer and use it in GitHub Desktop.
A neural network in RISC-V Assembly that classifies hand-written digits and outputs the correct digits. Implemented functions to handle ReLU, Dot Product, ArgMax, Matrix Multiplication, and reading and writing image pixel matrices into binary.
.globl argmax

.text
# =================================================================
# FUNCTION: Given an int vector, return the index of the largest
#   element. If there are multiple, return the one
#   with the smallest index.
# Arguments:
#   a0 (int*) is the pointer to the start of the vector
#   a1 (int)  is the # of elements in the vector
# Returns:
#   a0 (int)  is the first index of the largest element
# Exceptions:
#   - If the length of the vector is less than 1,
#     this function terminates the program with error code 120.
# Register use (leaf function: temporaries only, no stack frame):
#   t0: address of the element being examined
#   t1: index of the element being examined
#   t2: largest value seen so far
#   t3: index of the largest value seen so far
#   t4: value of the element being examined
# =================================================================
argmax:
    # Reject zero AND negative lengths: the contract is "less than 1",
    # so an equality test against zero is not enough.
    li   t0, 1
    blt  a1, t0, argmax_bad_length

    mv   t0, a0                 # cursor starts at element 0
    li   t1, 0
    lw   t2, 0(t0)              # element 0 is the initial maximum
    li   t3, 0

argmax_loop:
    addi t1, t1, 1
    bge  t1, a1, argmax_done    # every element has been examined
    addi t0, t0, 4
    lw   t4, 0(t0)
    # Only a strictly greater value replaces the maximum, so on ties the
    # earliest index is kept.
    bge  t2, t4, argmax_loop
    mv   t2, t4
    mv   t3, t1
    j    argmax_loop

argmax_done:
    mv   a0, t3
    ret

argmax_bad_length:
    li   a1, 120
    j    exit2
.globl classify

.text
classify:
    # =====================================
    # COMMAND LINE ARGUMENTS
    # =====================================
    # Args:
    #   a0 (int)    argc
    #   a1 (char**) argv
    #   a2 (int)    print_classification, if this is zero,
    #               you should print the classification. Otherwise,
    #               this function should not print ANYTHING.
    # Returns:
    #   a0 (int)    Classification
    # Exceptions:
    #   - If there are an incorrect number of command line args,
    #     this function terminates the program with exit code 121.
    #   - If malloc fails, this function terminates the program with
    #     exit code 116 (though we will also accept exit code 122).
    #
    # Usage:
    #   main.s <M0_PATH> <M1_PATH> <INPUT_PATH> <OUTPUT_PATH>
    #
    # Register use:
    #   s0  argv
    #   s1  print_classification flag
    #   s2  m0 buffer, later the output (scores) buffer
    #   s3  m0 rows (= hidden rows)
    #   s4  m0 columns, later the classification result
    #   s5  input buffer
    #   s6  input rows
    #   s7  input columns (= hidden columns = output columns)
    #   s8  hidden buffer (m0 * input, then ReLU'd in place)
    #   s9  m1 buffer
    #   s10 m1 rows (= output rows)
    #   s11 m1 columns
    # Stack frame (64 bytes):
    #   0..48   saved s0-s11 and ra
    #   52, 56  scratch words that read_matrix fills with rows / columns

    # Exit if wrong amount of args
    li   t0, 5
    bne  a0, t0, error121

    # Prologue
    addi sp, sp, -64
    sw   s0, 0(sp)
    sw   s1, 4(sp)
    sw   s2, 8(sp)
    sw   s3, 12(sp)
    sw   s4, 16(sp)
    sw   s5, 20(sp)
    sw   s6, 24(sp)
    sw   s7, 28(sp)
    sw   s8, 32(sp)
    sw   s9, 36(sp)
    sw   s10, 40(sp)
    sw   s11, 44(sp)
    sw   ra, 48(sp)

    mv   s0, a1
    mv   s1, a2

    # =====================================
    # LOAD MATRICES
    # =====================================
    # Load pretrained m0 (argv[1])
    lw   a0, 4(s0)
    addi a1, sp, 52
    addi a2, sp, 56
    jal  ra, read_matrix
    mv   s2, a0
    lw   s3, 52(sp)
    lw   s4, 56(sp)

    # Load input matrix (argv[3])
    lw   a0, 12(s0)
    addi a1, sp, 52
    addi a2, sp, 56
    jal  ra, read_matrix
    mv   s5, a0
    lw   s6, 52(sp)
    lw   s7, 56(sp)

    # =====================================
    # RUN LAYERS
    # =====================================
    # 1. LINEAR LAYER:    m0 * input
    # 2. NONLINEAR LAYER: ReLU(m0 * input)
    # 3. LINEAR LAYER:    m1 * ReLU(m0 * input)

    # hidden = malloc(4 * m0_rows * input_cols)
    mul  a0, s3, s7
    slli a0, a0, 2
    jal  ra, malloc
    beq  a0, x0, error116
    mv   s8, a0

    # hidden = m0 * input
    mv   a0, s2
    mv   a1, s3
    mv   a2, s4
    mv   a3, s5
    mv   a4, s6
    mv   a5, s7
    mv   a6, s8
    jal  ra, matmul

    # m0 and input are not needed past this point; release them so the
    # buffers handed back by read_matrix are not leaked.
    mv   a0, s2
    jal  ra, free
    mv   a0, s5
    jal  ra, free

    # hidden = ReLU(hidden), in place. relu is documented as returning
    # nothing, so keep using s8 rather than reading a0 afterwards.
    mv   a0, s8
    mul  a1, s3, s7
    jal  ra, relu

    # Load pretrained m1 (argv[2])
    lw   a0, 8(s0)
    addi a1, sp, 52
    addi a2, sp, 56
    jal  ra, read_matrix
    mv   s9, a0
    lw   s10, 52(sp)
    lw   s11, 56(sp)

    # scores = malloc(4 * m1_rows * hidden_cols)
    mul  a0, s10, s7
    slli a0, a0, 2
    jal  ra, malloc
    beq  a0, x0, error116
    mv   s2, a0

    # scores = m1 * hidden
    mv   a0, s9
    mv   a1, s10
    mv   a2, s11
    mv   a3, s8
    mv   a4, s3
    mv   a5, s7
    mv   a6, s2
    jal  ra, matmul

    # m1 and the hidden layer are finished with
    mv   a0, s9
    jal  ra, free
    mv   a0, s8
    jal  ra, free

    # =====================================
    # WRITE OUTPUT
    # =====================================
    # write_matrix(argv[4], scores, m1_rows, hidden_cols)
    lw   a0, 16(s0)
    mv   a1, s2
    mv   a2, s10
    mv   a3, s7
    jal  ra, write_matrix

    # =====================================
    # CALCULATE CLASSIFICATION/LABEL
    # =====================================
    mv   a0, s2
    mul  a1, s10, s7
    jal  ra, argmax
    mv   s4, a0

    # Print classification only when the flag is zero
    bne  s1, x0, classify_done
    mv   a1, s4
    jal  ra, print_int
    # Print newline afterwards for clarity
    li   a1, 10
    jal  ra, print_char

classify_done:
    mv   a0, s2
    jal  ra, free

    mv   a0, s4

    # Epilogue
    lw   ra, 48(sp)
    lw   s11, 44(sp)
    lw   s10, 40(sp)
    lw   s9, 36(sp)
    lw   s8, 32(sp)
    lw   s7, 28(sp)
    lw   s6, 24(sp)
    lw   s5, 20(sp)
    lw   s4, 16(sp)
    lw   s3, 12(sp)
    lw   s2, 8(sp)
    lw   s1, 4(sp)
    lw   s0, 0(sp)
    addi sp, sp, 64
    ret

error116:
    li   a1, 116
    j    exit2

error121:
    li   a1, 121
    j    exit2
.globl dot

.text
# =======================================================
# FUNCTION: Dot product of 2 int vectors
# Arguments:
#   a0 (int*) is the pointer to the start of v0
#   a1 (int*) is the pointer to the start of v1
#   a2 (int)  is the length of the vectors
#   a3 (int)  is the stride of v0
#   a4 (int)  is the stride of v1
# Returns:
#   a0 (int)  is the dot product of v0 and v1
# Exceptions:
#   - If the length of the vector is less than 1,
#     this function terminates the program with error code 123.
#   - If the stride of either vector is less than 1,
#     this function terminates the program with error code 124.
# Register use (leaf function: temporaries only, no stack frame):
#   t0: byte step between consecutive v0 elements (stride * 4)
#   t1: byte step between consecutive v1 elements (stride * 4)
#   t2: address of the current v0 element
#   t3: address of the current v1 element
#   t4: elements still to be processed
#   t5, t6: current elements / their product
# =======================================================
dot:
    # The contract is "less than 1", so negative values must be rejected
    # as well as zero.
    li   t0, 1
    blt  a2, t0, error123
    blt  a3, t0, error124
    blt  a4, t0, error124

    # Strides are given in elements; convert once to byte steps.
    slli t0, a3, 2
    slli t1, a4, 2

    mv   t2, a0
    mv   t3, a1
    mv   t4, a2
    li   a0, 0                  # running sum

dot_loop:
    lw   t5, 0(t2)
    lw   t6, 0(t3)
    mul  t5, t5, t6
    add  a0, a0, t5
    addi t4, t4, -1
    # Leave before advancing so nothing past the last element is read.
    beq  t4, x0, dot_done
    add  t2, t2, t0
    add  t3, t3, t1
    j    dot_loop

dot_done:
    ret

error123:
    li   a1, 123
    j    exit2

error124:
    li   a1, 124
    j    exit2
.import read_matrix.s
.import write_matrix.s
.import matmul.s
.import dot.s
.import relu.s
.import argmax.s
.import utils.s
.import classify.s
.globl main

# Thin entry point: hands control to classify and then terminates.
# a0 and a1 are passed through untouched; classify expects them to hold
# argc and argv. Nothing stops a caller from invoking classify again
# before exiting.
main:
    # print_classification = 0, i.e. classify should print its result
    li   a2, 0
    jal  ra, classify
    # normal termination
    jal  ra, exit
.globl matmul

.text
# =======================================================
# FUNCTION: Matrix Multiplication of 2 integer matrices
#   d = matmul(m0, m1)
# Arguments:
#   a0 (int*) is the pointer to the start of m0
#   a1 (int)  is the # of rows (height) of m0
#   a2 (int)  is the # of columns (width) of m0
#   a3 (int*) is the pointer to the start of m1
#   a4 (int)  is the # of rows (height) of m1
#   a5 (int)  is the # of columns (width) of m1
#   a6 (int*) is the pointer to the the start of d
# Returns:
#   None (void), sets d = matmul(m0, m1)
# Exceptions (checked in this order):
#   - If the dimensions of m0 do not make sense,
#     this function terminates the program with exit code 125.
#   - If the dimensions of m1 do not make sense,
#     this function terminates the program with exit code 126.
#   - If the dimensions of m0 and m1 don't match,
#     this function terminates the program with exit code 127.
# Register use:
#   s0: address of the current row of m0
#   s1: number of rows of m0
#   s2: number of columns of m0 (length of every dot product)
#   s3: base address of m1
#   s4: number of columns of m1 (also the stride down a column of m1)
#   s5: address of the next element of d to fill
#   s6: current row index
#   s7: current column index
#   s8: size of one m0 row in bytes
# =======================================================
matmul:
    # Dimension checks: every dimension must be at least 1, and the
    # inner dimensions must agree.
    li   t0, 1
    blt  a1, t0, error125
    blt  a2, t0, error125
    blt  a4, t0, error126
    blt  a5, t0, error126
    bne  a2, a4, error127

    # Prologue
    addi sp, sp, -48
    sw   ra, 0(sp)
    sw   s0, 4(sp)
    sw   s1, 8(sp)
    sw   s2, 12(sp)
    sw   s3, 16(sp)
    sw   s4, 20(sp)
    sw   s5, 24(sp)
    sw   s6, 28(sp)
    sw   s7, 32(sp)
    sw   s8, 36(sp)

    mv   s0, a0
    mv   s1, a1
    mv   s2, a2
    mv   s3, a3
    mv   s4, a5
    mv   s5, a6
    slli s8, s2, 2              # bytes per row of m0
    li   s6, 0

matmul_row:
    bge  s6, s1, matmul_done
    li   s7, 0

matmul_col:
    bge  s7, s4, matmul_next_row
    # d[row][col] = dot(row of m0 with stride 1,
    #                   column of m1 with stride = width of m1)
    mv   a0, s0
    slli t0, s7, 2
    add  a1, s3, t0
    mv   a2, s2
    li   a3, 1
    mv   a4, s4
    jal  ra, dot
    sw   a0, 0(s5)
    addi s5, s5, 4
    addi s7, s7, 1
    j    matmul_col

matmul_next_row:
    add  s0, s0, s8
    addi s6, s6, 1
    j    matmul_row

matmul_done:
    # Epilogue
    lw   s8, 36(sp)
    lw   s7, 32(sp)
    lw   s6, 28(sp)
    lw   s5, 24(sp)
    lw   s4, 20(sp)
    lw   s3, 16(sp)
    lw   s2, 12(sp)
    lw   s1, 8(sp)
    lw   s0, 4(sp)
    lw   ra, 0(sp)
    addi sp, sp, 48
    ret

error125:
    li   a1, 125
    j    exit2

error126:
    li   a1, 126
    j    exit2

error127:
    li   a1, 127
    j    exit2
.globl read_matrix

.text
# ==============================================================================
# FUNCTION: Allocates memory and reads in a binary file as a matrix of integers
#
# FILE FORMAT:
#   The first 8 bytes are two 4 byte ints representing the # of rows and columns
#   in the matrix. Every 4 bytes afterwards is an element of the matrix in
#   row-major order.
# Arguments:
#   a0 (char*) is the pointer to string representing the filename
#   a1 (int*)  is a pointer to an integer, we will set it to the number of rows
#   a2 (int*)  is a pointer to an integer, we will set it to the number of columns
# Returns:
#   a0 (int*)  is the pointer to the matrix in memory
# Exceptions:
#   - If malloc returns an error,
#     this function terminates the program with error code 116.
#   - If you receive an fopen error or eof,
#     this function terminates the program with error code 117.
#   - If you receive an fread error or eof,
#     this function terminates the program with error code 118.
#   - If you receive an fclose error or eof,
#     this function terminates the program with error code 119.
# Steps:
#   1. open the file for reading
#   2. fread the row count and the column count straight into *a1 and *a2
#   3. malloc rows * columns * 4 bytes
#   4. fread the elements straight into the new buffer
#   5. close the file and return the buffer
# Register use:
#   s0: pointer to the row-count slot
#   s1: pointer to the column-count slot
#   s2: file descriptor
#   s3: size of the matrix data in bytes
#   s4: pointer to the matrix buffer
# ==============================================================================
read_matrix:
    # Prologue
    addi sp, sp, -32
    sw   ra, 0(sp)
    sw   s0, 4(sp)
    sw   s1, 8(sp)
    sw   s2, 12(sp)
    sw   s3, 16(sp)
    sw   s4, 20(sp)

    mv   s0, a1
    mv   s1, a2

    # fd = fopen(filename, 0); a return of -1 signals failure
    mv   a1, a0
    li   a2, 0
    jal  ra, fopen
    li   t0, -1
    beq  a0, t0, error117
    mv   s2, a0

    # Row count: exactly 4 bytes must be read
    mv   a1, s2
    mv   a2, s0
    li   a3, 4
    jal  ra, fread
    li   t0, 4
    bne  a0, t0, error118

    # Column count: exactly 4 bytes must be read
    mv   a1, s2
    mv   a2, s1
    li   a3, 4
    jal  ra, fread
    li   t0, 4
    bne  a0, t0, error118

    # bytes = rows * columns * 4
    lw   t0, 0(s0)
    lw   t1, 0(s1)
    mul  s3, t0, t1
    slli s3, s3, 2

    # Buffer for the elements
    mv   a0, s3
    jal  ra, malloc
    beq  a0, x0, error116
    mv   s4, a0

    # Elements: the whole payload must arrive in one read
    mv   a1, s2
    mv   a2, s4
    mv   a3, s3
    jal  ra, fread
    bne  a0, s3, error118

    # Close the file; -1 signals failure
    mv   a1, s2
    jal  ra, fclose
    li   t0, -1
    beq  a0, t0, error119

    # Result in a0; a1/a2 are handed back pointing at the row/column slots
    mv   a0, s4
    mv   a1, s0
    mv   a2, s1

    # Epilogue
    lw   s4, 20(sp)
    lw   s3, 16(sp)
    lw   s2, 12(sp)
    lw   s1, 8(sp)
    lw   s0, 4(sp)
    lw   ra, 0(sp)
    addi sp, sp, 32
    ret

error116:
    li   a1, 116
    j    exit2

error117:
    li   a1, 117
    j    exit2

error118:
    li   a1, 118
    j    exit2

error119:
    li   a1, 119
    j    exit2
.globl relu

.text
# ==============================================================================
# FUNCTION: Performs an inplace element-wise ReLU on an array of ints
# Arguments:
#   a0 (int*) is the pointer to the array
#   a1 (int)  is the # of elements in the array
# Returns:
#   None (a0 and a1 are not modified by this routine)
# Exceptions:
#   - If the length of the vector is less than 1,
#     this function terminates the program with error code 115.
# Register use (leaf function: temporaries only, no stack frame):
#   t0: address of the element being examined
#   t1: elements still to be processed
#   t2: value of the element being examined
# ==============================================================================
relu:
    # Reject zero AND negative lengths: the contract is "less than 1".
    li   t0, 1
    blt  a1, t0, relu_bad_length

    mv   t0, a0
    mv   t1, a1

relu_loop:
    lw   t2, 0(t0)
    bge  t2, x0, relu_next      # non-negative values are left alone
    sw   x0, 0(t0)              # negative values are clamped to 0
relu_next:
    addi t0, t0, 4
    addi t1, t1, -1
    bne  t1, x0, relu_loop
    ret

relu_bad_length:
    li   a1, 115
    j    exit2
##############################################################
# Do not modify! (But feel free to use the functions provided)
##############################################################
#define c_print_int 1
#define c_print_str 4
#define c_atoi 5
#define c_sbrk 9
#define c_exit 10
#define c_print_char 11
#define c_openFile 13
#define c_readFile 14
#define c_writeFile 15
#define c_closeFile 16
#define c_exit2 17
#define c_fflush 18
#define c_feof 19
#define c_ferror 20
#define c_printHex 34
# ecall wrappers
.globl print_int, print_str, atoi, sbrk, exit, print_char, fopen, fread, fwrite, fclose, exit2, fflush, ferror, print_hex
# helper functions
.globl file_error, print_int_array, malloc, free, print_num_alloc_blocks, num_alloc_blocks
# unittest helper functions
.globl compare_int_array
.data
error_string: .string "This library file should not be directly called!"
.text
# Exits if you run this file
# Guard entry point for the library: prints error_string via print_str and
# then terminates through exit2 with exit code 1.
main:
la a1 error_string
jal print_str
li a1 1
jal exit2
# End main
#================================================================
# void print_int(int a1)
# Prints the integer in a1.
# args:
# a1 = integer to print
# return:
# void
# note:
# a0 is overwritten with the ecall code c_print_int.
#================================================================
print_int:
li a0 c_print_int
ecall
ret
#================================================================
# void print_str(char *a1)
# Prints the null-terminated string at address a1.
# args:
# a1 = address of the string you want printed.
# return:
# void
# note:
# a0 is overwritten with the ecall code c_print_str.
#================================================================
print_str:
li a0 c_print_str
ecall
ret
#================================================================
# int atoi(char* a1)
# Returns the integer version of the string at address a1.
# args:
# a1 = address of the string you want to turn into an integer.
# return:
# a0 = Integer representation of string
# note:
# Implemented as the ecall whose code is c_atoi.
#================================================================
atoi:
li a0 c_atoi
ecall
ret
#================================================================
# void *sbrk(int a1)
# Allocates a1 bytes onto the heap.
# args:
# a1 = Number of bytes you want to allocate.
# return:
# a0 = Pointer to the start of the allocated memory
# note:
# Implemented as the ecall whose code is c_sbrk.
#================================================================
sbrk:
li a0 c_sbrk
ecall
ret
#================================================================
# void noreturn exit()
# Exits the program with a zero exit code.
# args:
# None
# return:
# No Return
# note:
# There is deliberately no ret after the ecall: the routine is declared
# noreturn. If the ecall ever did return, execution would fall through
# into the code that follows this routine.
#================================================================
exit:
li a0 c_exit
ecall
#================================================================
# void print_char(char a1)
# Prints the ASCII character in a1 to the console.
# args:
# a1 = character to print
# return:
# void
# note:
# a0 is overwritten with the ecall code c_print_char.
#================================================================
print_char:
li a0 c_print_char
ecall
ret
#================================================================
# int fopen(char *a1, int a2)
# Opens file with name a1 with permissions a2.
# args:
# a1 = filepath
# a2 = permissions (0, 1, 2, 3, 4, 5 = r, w, a, r+, w+, a+)
# return:
# a0 = file descriptor
# note:
# Implemented as the ecall whose code is c_openFile. The marker comment
# before ret is left exactly as-is; it appears to be a placeholder that
# external tooling looks for.
#================================================================
fopen:
li a0 c_openFile
ecall
#FOPEN_RETURN_HOOK
ret
#================================================================
# int fread(int a1, void *a2, size_t a3)
# Reads a3 bytes of the file into the buffer a2.
# args:
# a1 = file descriptor
# a2 = pointer to the buffer you want to write the read bytes to.
# a3 = Number of bytes to be read.
# return:
# a0 = Number of bytes actually read.
# note:
# Implemented as the ecall whose code is c_readFile. The marker comment
# before ret is left exactly as-is; it appears to be a placeholder that
# external tooling looks for.
#================================================================
fread:
li a0 c_readFile
ecall
#FREAD_RETURN_HOOK
ret
#================================================================
# int fwrite(int a1, void *a2, size_t a3, size_t a4)
# Writes a3 * a4 bytes from the buffer in a2 to the file descriptor a1.
# args:
# a1 = file descriptor
# a2 = Buffer to read from
# a3 = Number of items to read from the buffer.
# a4 = Size of each item in the buffer.
# return:
# a0 = Number of elements written. If this is less than a3,
# it is either an error or EOF. You will also need to still flush the fd.
# note:
# Implemented as the ecall whose code is c_writeFile. The marker comment
# before ret is left exactly as-is; it appears to be a placeholder that
# external tooling looks for.
#================================================================
fwrite:
li a0 c_writeFile
ecall
#FWRITE_RETURN_HOOK
ret
#================================================================
# int fclose(int a1)
# Closes the file descriptor a1.
# args:
# a1 = file descriptor
# return:
# a0 = 0 on success, and EOF (-1) otherwise.
# note:
# Implemented as the ecall whose code is c_closeFile. The marker comment
# before ret is left exactly as-is; it appears to be a placeholder that
# external tooling looks for.
#================================================================
fclose:
li a0 c_closeFile
ecall
#FCLOSE_RETURN_HOOK
ret
#================================================================
# void noreturn exit2(int a1)
# Exits the program with error code a1.
# args:
# a1 = Exit code.
# return:
# This program does not return.
# note:
# Implemented as the ecall whose code is c_exit2. A ret follows the ecall
# even though the routine is documented as not returning.
#================================================================
exit2:
li a0 c_exit2
ecall
ret
#================================================================
# int fflush(int a1)
# Flushes the data to the filesystem.
# args:
# a1 = file descriptor
# return:
# a0 = 0 on success, and EOF (-1) otherwise.
# note:
# Implemented as the ecall whose code is c_fflush.
#================================================================
fflush:
li a0 c_fflush
ecall
ret
#================================================================
# int ferror(int a1)
# Returns a nonzero value if the file stream has errors, otherwise it returns 0.
# args:
# a1 = file descriptor
# return:
# a0 = Nonzero value if the file stream has errors, 0 otherwise.
# note:
# Implemented as the ecall whose code is c_ferror.
#================================================================
ferror:
li a0 c_ferror
ecall
ret
#================================================================
# void print_hex(int a1)
# Prints the word in a1 as a hex value.
# args:
# a1 = The word which will be printed as a hex value.
# return:
# void
# note:
# a0 is overwritten with the ecall code c_printHex.
#================================================================
print_hex:
li a0 c_printHex
ecall
ret
#================================================================
# void* malloc(int a0)
# Allocates heap memory and return a pointer to it
# args:
# a0 is the # of bytes to allocate heap memory for
# return:
# a0 is the pointer to the allocated heap memory
# note:
# Overwrites a1 (with the requested size) and a6 in addition to a0.
#================================================================
malloc:
# The request is made with ecall code 0x3CC and a6 = 1, with the size
# moved into a1. NOTE(review): an earlier comment here read "Call to sbrk",
# but this does not go through the sbrk wrapper; presumably a6 selects the
# allocate operation of the simulator's allocator - confirm.
mv a1 a0
li a0 0x3CC
addi a6 x0 1
ecall
#MALLOC_RETURN_HOOK
ret
#================================================================
# void free(int a0)
# Frees heap memory referenced by pointer
# args:
# a0 is the pointer to heap memory to free
# return:
# void
# note:
# Uses the same ecall code as malloc (0x3CC) with a6 = 4 and the pointer
# moved into a1, so a0, a1 and a6 are all overwritten.
#================================================================
free:
mv a1 a0
li a0 0x3CC
addi a6 x0 4
ecall
ret
#================================================================
# int num_alloc_blocks(void)
# Returns the number of currently allocated blocks
# args:
# void
# return:
# a0 is the # of allocated blocks
# note:
# Uses ecall code 0x3CC with a6 = 5; a6 is overwritten.
#================================================================
num_alloc_blocks:
li a0, 0x3CC
li a6, 5
ecall
ret
# void print_num_alloc_blocks(void)
# Prints the value returned by num_alloc_blocks followed by a newline.
# ra is saved on the stack because this routine makes calls of its own.
print_num_alloc_blocks:
addi sp, sp -4
sw ra 0(sp)
jal num_alloc_blocks
# hand the count to print_int, which takes its argument in a1
mv a1 a0
jal print_int
li a1 '\n'
jal print_char
lw ra 0(sp)
addi sp, sp 4
ret
#================================================================
# void print_int_array(int* a0, int a1, int a2)
# Prints an integer array, with spaces between the elements
# and a newline after each row.
# args:
# a0 is the pointer to the start of the array
# a1 is the # of rows in the array
# a2 is the # of columns in the array
# return:
# void
# registers:
# s0 = array pointer, s1 = rows, s2 = columns,
# s3 = current row index, s4 = current column index
#================================================================
print_int_array:
# Prologue
addi sp sp -24
sw s0 0(sp)
sw s1 4(sp)
sw s2 8(sp)
sw s3 12(sp)
sw s4 16(sp)
sw ra 20(sp)
# Save arguments
mv s0 a0
mv s1 a1
mv s2 a2
# Set outer loop index
li s3 0
outer_loop_start:
# Check outer loop condition
beq s3 s1 outer_loop_end
# Set inner loop index
li s4 0
inner_loop_start:
# Check inner loop condition
beq s4 s2 inner_loop_end
# t0 = row index * len(row) + column index
mul t0 s2 s3
add t0 t0 s4
# scale the element index to a byte offset (4 bytes per element)
slli t0 t0 2
# Load matrix element
add t0 t0 s0
lw t1 0(t0)
# Print matrix element
mv a1 t1
jal print_int
# Print whitespace
li a1 ' '
jal print_char
addi s4 s4 1
j inner_loop_start
inner_loop_end:
# Print newline
li a1 '\n'
jal print_char
addi s3 s3 1
j outer_loop_start
outer_loop_end:
# Epilogue
lw s0 0(sp)
lw s1 4(sp)
lw s2 8(sp)
lw s3 12(sp)
lw s4 16(sp)
lw ra 20(sp)
addi sp sp 24
ret
#================================================================
# void compare_int_array(int a0, int* a1, int* a2, int a3, char* a4)
# Compares two integer arrays element by element. On the first mismatch it
# prints the message in a4, prints the actual data as a single row, and
# exits with code a0. Returns normally when all a3 elements match.
# args:
# a0 is the base exit code that will be used if an unequal element is found
# a1 is the pointer to the expected data
# a2 is the pointer to the actual data
# a3 is the number of elements in each array
# a4 is the error message
# return:
# void
#================================================================
compare_int_array:
# Prologue
# NOTE(review): s3 and s4 are saved and restored but never used below.
addi sp sp -24
sw s0 0(sp)
sw s1 4(sp)
sw s2 8(sp)
sw s3 12(sp)
sw s4 16(sp)
sw ra 20(sp)
# save pointer to the start of the actual data (a2) in s1; a2 itself is
# advanced by the loop below
mv s1, a2
# t0: index of the current element
mv t0 zero
loop_start:
# we are done once t0 >= a3
bge t0, a3, end
# t1 := *a1
lw t1, 0(a1)
# t2 := *a2
lw t2, 0(a2)
# if the values are different -> fail
bne t1, t2, fail
# go to next value
addi t0, t0, 1
addi a1, a1, 4
addi a2, a2, 4
j loop_start
fail:
# exit code: a0
mv s0, a0
# remember length
mv s2, a3
# print user supplied error message
mv a1, a4
jal print_str
# print actual data as 1 row of s2 elements
mv a0, s1
li a1, 1
mv a2, s2
jal print_int_array
# exit with user defined error code
mv a1, s0
jal exit2
end:
# Epilogue
lw s0 0(sp)
lw s1 4(sp)
lw s2 8(sp)
lw s3 12(sp)
lw s4 16(sp)
lw ra 20(sp)
addi sp sp 24
ret
.globl write_matrix

.text
# ==============================================================================
# FUNCTION: Writes a matrix of integers into a binary file
# FILE FORMAT:
#   The first 8 bytes of the file will be two 4 byte ints representing the
#   numbers of rows and columns respectively. Every 4 bytes thereafter is an
#   element of the matrix in row-major order.
# Arguments:
#   a0 (char*) is the pointer to string representing the filename
#   a1 (int*)  is the pointer to the start of the matrix in memory
#   a2 (int)   is the number of rows in the matrix
#   a3 (int)   is the number of columns in the matrix
# Returns:
#   None
# Exceptions:
#   - If you receive an fopen error or eof,
#     this function terminates the program with error code 112.
#   - If you receive an fwrite error or eof,
#     this function terminates the program with error code 113.
#   - If you receive an fclose error or eof,
#     this function terminates the program with error code 114.
# Steps:
#   1. open the file for writing
#   2. write the two-word header (rows, columns) from the stack frame
#   3. write rows * columns matrix elements
#   4. close the file
# Register use:
#   s0: file descriptor
#   s1: pointer to the matrix
#   s2: number of elements (rows * columns)
# Stack frame (32 bytes, keeps sp word- and 16-byte aligned):
#   0..12   saved ra, s0-s2
#   16, 20  header buffer: rows, columns
# ==============================================================================
write_matrix:
    # Prologue. The frame size must be a multiple of 4 (16 per the ABI);
    # an odd size such as 62 would misalign every word access through sp.
    addi sp, sp, -32
    sw   ra, 0(sp)
    sw   s0, 4(sp)
    sw   s1, 8(sp)
    sw   s2, 12(sp)

    # Build the header in the frame while a2/a3 are still live
    sw   a2, 16(sp)
    sw   a3, 20(sp)
    mv   s1, a1
    mul  s2, a2, a3

    # fd = fopen(filename, 1); a return of -1 signals failure
    mv   a1, a0
    li   a2, 1
    jal  ra, fopen
    li   t0, -1
    beq  a0, t0, error112
    mv   s0, a0

    # Header: 2 items of 4 bytes each
    mv   a1, s0
    addi a2, sp, 16
    li   a3, 2
    li   a4, 4
    jal  ra, fwrite
    li   t0, 2
    bne  a0, t0, error113

    # Elements: rows * columns items of 4 bytes each
    mv   a1, s0
    mv   a2, s1
    mv   a3, s2
    li   a4, 4
    jal  ra, fwrite
    bne  a0, s2, error113

    # Close the file; -1 signals failure
    mv   a1, s0
    jal  ra, fclose
    li   t0, -1
    beq  a0, t0, error114

    # Epilogue
    lw   s2, 12(sp)
    lw   s1, 8(sp)
    lw   s0, 4(sp)
    lw   ra, 0(sp)
    addi sp, sp, 32
    ret

error112:
    li   a1, 112
    j    exit2

error113:
    li   a1, 113
    j    exit2

error114:
    li   a1, 114
    j    exit2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment