Last active
December 15, 2015 13:08
-
-
Save OndrejSlamecka/5264667 to your computer and use it in GitHub Desktop.
Solves a system of linear equations over the field Z_p, where p is prime. When there are many solutions, it finds the lexicographically smallest one (smallest x1, then x2, …). With the -e parameter it prints the number of solutions, up to 10^12. A complete solution to http://cecko.eu/public/lineq_2013 [cs]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h> | |
#include <stdlib.h> | |
#include <string.h> | |
/* --- IO --- */ | |
/**
 * @brief Reads one line from stdin into input, dropping every space character.
 *
 * Reading stops at '\n', at end of input, or at a NUL byte. The stored text
 * is always NUL-terminated. The buffer is never reallocated: it is passed by
 * value, so growing it here would leave the caller holding a dangling
 * pointer. Characters that do not fit are discarded, but the rest of the line
 * is still consumed so the next call starts on a fresh line.
 *
 * @param input destination buffer
 * @param size  capacity of input in bytes, including the terminator
 * @return 1 on success; 0 if the line was truncated, the buffer is unusable,
 *         or end of input was reached before any character was stored
 */
int readSpacelessLine(char *input, size_t size)
{
    if (input == NULL || size == 0) {
        return 0;
    }

    size_t length = 0;
    int truncated = 0;
    int c;

    // EOF must be tested explicitly: it is non-zero and differs from '\n'.
    while ((c = getchar()) != EOF && c != '\0' && c != '\n') {
        if (c == ' ') {
            continue;
        }
        if (length + 1 < size) { // keep one byte for the terminator
            input[length++] = (char) c;
        } else {
            truncated = 1;
        }
    }
    input[length] = '\0';

    if (c == EOF && length == 0) {
        return 0; // nothing left to read
    }
    return !truncated;
}
/**
 * @brief Development helper: writes the augmented matrix to stdout.
 *
 * Each row is printed on its own line; every cell is right-aligned in a
 * five-character field followed by a space.
 *
 * @param nEquations number of rows of A
 * @param nVariables number of variable columns; A has one extra column
 * @param A          matrix to print
 */
void printMatrix(int nEquations, int nVariables, long int A[nEquations][nVariables + 1])
{
    const int width = nVariables + 1;
    int row = 0;

    while (row < nEquations) {
        const long int *cells = A[row];
        for (int col = 0; col < width; col++) {
            printf("%5ld ", cells[col]);
        }
        printf("\n");
        row++;
    }
}
/**
 * @brief Prints the solution vector to stdout, one "xK = value" line per variable.
 * @param solution   values of the variables; x1 is stored at position 0
 * @param nVariables number of entries to print
 */
void printSolution(long int solution[], int nVariables)
{
    int position = 0;

    while (position < nVariables) {
        long int value = solution[position];
        position++; // variables are numbered from 1
        printf("x%d = %ld\n", position, value);
    }
}
/* --- Maths --- */ | |
/**
 * @brief Primality test by trial division.
 * @param n number to test
 * @return 1 if n is prime, 0 otherwise (including every n below 2)
 */
int isPrime(long int n)
{
    if (n < 2) {
        return 0; // negatives, 0 and 1 are not prime
    }
    if (n < 4) {
        return 1; // 2 and 3
    }
    if (n % 2 == 0) {
        return 0;
    }
    // d <= n / d is the overflow-free form of d * d <= n
    for (long int d = 3; d <= n / d; d += 2) {
        if (n % d == 0) {
            return 0;
        }
    }
    return 1;
}
/**
 * @brief Remainder that is non-negative for a positive modulus, unlike plain %.
 * @param n       value to reduce
 * @param modulus modulus
 * @return n reduced modulo modulus
 */
int mod(long int n, long int modulus)
{
    long int remainder = n % modulus;      // may be negative when n is
    long int lifted = remainder + modulus; // shift into the positive range
    return lifted % modulus;
}
/**
 * @brief Returns the modular multiplicative inverse of number.
 *
 * Uses the extended Euclidean algorithm. The input is first reduced into
 * [0, modulus) so that negative numbers give the correct inverse.
 *
 * @param number  value to invert; may be negative or larger than modulus
 * @param modulus modulus, expected to be positive
 * @return the inverse in [0, modulus), or 0 when no inverse exists
 *         (number is not coprime to modulus, or modulus is not positive)
 */
int modInverse(long int number, long int modulus)
{
    if (modulus <= 0) {
        return 0;
    }

    long int a = number % modulus;
    if (a < 0) {
        a += modulus;
    }
    long int b = modulus;

    // Invariant: a == coef * number and b == prevCoef * number (mod modulus)
    long int coef = 1, prevCoef = 0;
    while (a != 0) {
        long int quotient = b / a;
        long int remainder = b % a;
        long int next = prevCoef - quotient * coef;
        prevCoef = coef;
        coef = next;
        b = a;
        a = remainder;
    }

    if (b != 1) {
        return 0; // gcd(number, modulus) != 1: not invertible
    }

    long int inverse = prevCoef % modulus;
    if (inverse < 0) {
        inverse += modulus;
    }
    return (int) inverse;
}
/**
 * @brief Computes b / a modulo p as b times the modular inverse of a.
 * @param b dividend
 * @param a divisor
 * @param p modulus
 * @return mod(b * modInverse(a, p), p)
 */
long int moduloDivision(long int b, long int a, long int p)
{
    long int reciprocal = modInverse(a, p);
    long int scaled = b * reciprocal;
    return mod(scaled, p);
}
/**
 * @brief Sets every cell of the augmented matrix to zero.
 * @param rows       number of rows
 * @param nVariables number of variable columns; each row has one extra column
 * @param matrix     matrix to clear
 */
void initMatrix(int rows, int nVariables, long int matrix[rows][nVariables + 1])
{
    const int width = nVariables + 1;

    for (int r = 0; r < rows; r++) {
        long int *cells = matrix[r];
        int c = 0;
        while (c < width) {
            cells[c++] = 0;
        }
    }
}
/**
 * @brief Eliminates the augmented matrix over Z_modulus and classifies the
 *        system by its number of solutions.
 *
 * Columns are processed from the last variable down to the first. Each pivot
 * row is moved to the lowest row not yet holding a pivot, scaled so that the
 * pivot equals 1, and then used to clear that column in all rows above it.
 *
 * @param rows       number of rows of A
 * @param nVariables number of variables; A has nVariables + 1 columns, the
 *                   last one holding the right-hand side
 * @param A          augmented matrix, modified in place
 * @param modulus    the prime modulus
 * @param nEquations out: number of rows that still contain a non-zero
 *                   coefficient. When 0 is returned it holds a partial count.
 * @param computeNumberOfSolutions 1 to compute the exact number of solutions,
 *                   0 to only distinguish none / one / many
 * @return 0 for no solution, 1 for exactly one; otherwise 2 when
 *         computeNumberOfSolutions is 0, else the exact number of solutions,
 *         or -1 when that number exceeds 10^12
 */
long long int eliminateMatrix(int rows, int nVariables, long int A[rows][nVariables + 1], long int modulus, int *nEquations, int computeNumberOfSolutions)
{
    const int cols = nVariables + 1;
    int nPivots = 0;

    // ELIMINATION
    for (int x = nVariables - 1; x >= 0; x--) {
        // Every row already holds a pivot. The test must be >=: with
        // nPivots == rows the pivot position below would be -1.
        if (nPivots >= rows) {
            break;
        }
        const int pivotPosition = rows - nPivots - 1;

        // If the pivot candidate is 0, look for a non-zero value above it
        if (A[pivotPosition][x] == 0) {
            int nonZeroValueRow = -1;
            for (int y = pivotPosition - 1; y >= 0; y--) {
                if (A[y][x] != 0) {
                    nonZeroValueRow = y;
                    break;
                }
            }
            if (nonZeroValueRow == -1) { // no pivot in this column
                continue;
            }
            for (int j = 0; j < cols; j++) { // swap the two rows
                long int swap = A[pivotPosition][j];
                A[pivotPosition][j] = A[nonZeroValueRow][j];
                A[nonZeroValueRow][j] = swap;
            }
        }

        // Scale the pivot row so that the pivot becomes 1
        const long int inverse = modInverse(A[pivotPosition][x], modulus);
        for (int j = 0; j < cols; j++) {
            A[pivotPosition][j] = mod(A[pivotPosition][j] * inverse, modulus);
        }

        // Clear column x in every row above the pivot row
        for (int i = pivotPosition - 1; i >= 0; i--) {
            const long int multiplier = A[i][x];
            for (int j = 0; j < cols; j++) {
                A[i][j] = mod(A[i][j] - A[pivotPosition][j] * multiplier, modulus);
            }
        }
        nPivots++;
    }

    // CALCULATE NUMBER OF SOLUTIONS
    *nEquations = rows;
    for (int y = 0; y < rows; y++) {
        int emptyRow = 1;
        for (int x = 0; x < nVariables; x++) {
            if (A[y][x] != 0) {
                emptyRow = 0;
                break;
            }
        }
        if (emptyRow) {
            *nEquations -= 1;
            if (A[y][nVariables] != 0) { // a row like (0 0 0 | 1) has no solution
                return 0;
            }
        }
    }

    if (*nEquations == nVariables) {
        return 1;
    }
    if (!computeNumberOfSolutions) {
        return 2;
    }

    // number of solutions = modulus ^ (number of free variables)
    const long long int limit = 1000000000000LL;
    const int freeVariables = nVariables - *nEquations;
    long long int product = 1;
    for (int i = 0; i < freeVariables; i++) {
        // Compare before multiplying so the product cannot overflow
        if (modulus > 0 && product > limit / modulus) {
            return -1;
        }
        product *= modulus;
    }
    return product;
}
/**
 * @brief Computes one solution from an already eliminated matrix.
 *
 * Only the last nEquations rows are read. Within a row, the right-most
 * non-zero coefficient whose variable has not been resolved yet is taken as
 * that row's pivot; all other coefficients are multiplied by the current
 * values in solution and subtracted from the right-hand side. Variables that
 * never become a pivot keep the value 0.
 *
 * @param nRows      number of rows of A
 * @param nEquations number of trailing rows of A to process
 * @param nVariables number of variables; A has nVariables + 1 columns
 * @param A          eliminated augmented matrix
 * @param solution   out: array of nVariables values
 * @param modulus    the prime modulus
 */
void solveReducedMatrix(int nRows, int nEquations, int nVariables, long int A[nRows][nVariables + 1], long int solution[], long int modulus)
{
    if (nVariables <= 0) {
        return; // nothing to solve; also avoids a zero-length array
    }

    int resolved[nVariables];
    // Both arrays hold exactly nVariables elements, not nVariables + 1
    for (int i = 0; i < nVariables; i++) {
        solution[i] = 0;
        resolved[i] = 0;
    }

    for (int y = nRows - nEquations; y < nRows; y++) {
        int pivot = -1;
        long int value = A[y][nVariables];
        for (int x = nVariables - 1; x >= 0; x--) {
            if (A[y][x] != 0 && pivot == -1 && !resolved[x]) {
                pivot = x;
                resolved[x] = 1;
            } else {
                value = mod(value - A[y][x] * solution[x], modulus);
            }
        }
        if (pivot != -1) { // -1 means this row offered no new pivot
            solution[pivot] = moduloDivision(value, A[y][pivot], modulus);
        }
    }
}
/* --- Parsing --- */ | |
/**
 * @brief Parses a single term such as "x3", "12x5", "-x2" or "-4x1".
 *
 * The text before the first 'x' is the coefficient and the text after it is
 * the variable index. A missing coefficient means 1, a lone '-' means -1 and
 * a lone '+' means 1. The input string is not modified.
 *
 * @param input term text, NUL-terminated
 * @param index out: variable index; 0 when the term carries no index
 * @param coef  out: coefficient; 0 when the text before 'x' is not a number
 */
void parseTerm(char *input, int *index, long int *coef)
{
    *index = 0;
    *coef = 0;
    if (input == NULL) {
        return;
    }

    const char *marker = strchr(input, 'x');
    char *end = NULL;
    long int value = strtol(input, &end, 10);

    if (end == input) { // no digits in front of the 'x'
        if (marker == input || (marker == input + 1 && *input == '+')) {
            value = 1;
        } else if (marker == input + 1 && *input == '-') {
            value = -1;
        }
    }
    *coef = value;

    if (marker != NULL) {
        // strtol yields 0 when no digits follow the 'x'
        *index = (int) strtol(marker + 1, (char **) NULL, 10);
    }
}
/**
 * @brief Parses a '+'-separated polynomial and accumulates it into row.
 *
 * Terms containing 'x' are added to the coefficient of their variable; terms
 * without 'x' are added to the constant stored in row[nVariables]. Sums are
 * reduced with % modulus, so an entry can be negative after this call.
 * Terms whose variable index lies outside 1..nVariables are ignored, and a
 * term longer than the internal buffer is truncated.
 *
 * @param input      polynomial text without spaces; NULL or "" adds nothing
 * @param row        accumulator with nVariables + 1 entries
 * @param nVariables number of variables
 * @param modulus    modulus used to reduce the sums
 */
void parsePolynomial(char *input, long int row[], int nVariables, int modulus)
{
    if (input == NULL || *input == '\0') {
        return;
    }

    char chunk[50];
    size_t used = 0;
    const char *cursor = input;
    char c;

    do {
        c = *cursor++;
        if (c != '+' && c != '\0') {
            if (used < sizeof chunk - 1) { // never write past the buffer
                chunk[used++] = c;
            }
            continue;
        }

        // A '+' or the end of the string closes the current term
        chunk[used] = '\0';
        used = 0;

        if (strchr(chunk, 'x') == NULL) {
            row[nVariables] = (row[nVariables] + strtol(chunk, (char **) NULL, 10)) % modulus;
        } else {
            int varIndex = 0;
            long int coef = 0;
            parseTerm(chunk, &varIndex, &coef);
            if (varIndex >= 1 && varIndex <= nVariables) {
                row[varIndex - 1] = (row[varIndex - 1] + coef) % modulus;
            }
        }
    } while (c != '\0');
}
/**
 * @brief Combines the two sides of an equation into one row, in place.
 *
 * Variable coefficients become pol1 - pol2, while the constant becomes
 * pol2 - pol1, i.e. the constants end up on the opposite side from the
 * variables. For a positive modulus every result is in [0, modulus).
 *
 * @param pol1       left-hand side; receives the result
 * @param pol2       right-hand side
 * @param nVariables number of variables; both arrays have one extra entry
 * @param modulus    modulus
 */
void subtractPolynomials(long int pol1[], long int pol2[], int nVariables, int modulus)
{
    int k = 0;
    while (k < nVariables) {
        long int difference = (pol1[k] - pol2[k]) % modulus;
        pol1[k] = (difference + modulus) % modulus;
        k++;
    }

    // Constant term: negate the left constant, then add the right one
    long int negated = ((-pol1[nVariables]) % modulus + modulus) % modulus;
    long int total = pol2[nVariables] + negated;
    pol1[nVariables] = (total % modulus + modulus) % modulus;
}
/**
 * @brief Parses "left=right" and stores the combined equation in row.
 *
 * The line is split in place at the first '='; anything from a second '='
 * onwards is ignored. A line without '=' is treated as having an empty
 * right-hand side.
 *
 * @param input      equation text without spaces; modified in place
 * @param row        row of nVariables + 1 entries; the left side is
 *                   accumulated into it before the right side is subtracted
 * @param nVariables number of variables
 * @param modulus    modulus
 */
void parseEquation(char *input, long int row[], int nVariables, int modulus)
{
    if (input == NULL || nVariables < 0) {
        return;
    }

    char empty[] = "";
    char *left = input;
    char *right = empty; // used when the line has no '='

    char *separator = strchr(input, '=');
    if (separator != NULL) {
        *separator = '\0';
        right = separator + 1;
        char *extra = strchr(right, '=');
        if (extra != NULL) {
            *extra = '\0';
        }
    }

    long int aRight[nVariables + 1];
    for (int i = 0; i < nVariables + 1; i++) {
        aRight[i] = 0;
    }

    parsePolynomial(left, row, nVariables, modulus);
    parsePolynomial(right, aRight, nVariables, modulus);
    subtractPolynomials(row, aRight, nVariables, modulus);
}
/**
 * @brief Entry point: reads a prime and then a sequence of systems from stdin.
 *
 * Input: a line with the modulus, then for each system a line
 * "nVariables nRows" followed by nRows equation lines. The line "-1 -1" or
 * end of input stops the program. With "-e" as the only argument the exact
 * number of solutions is printed instead of "MANY SOLUTIONS".
 *
 * @return 0 on normal termination, EXIT_FAILURE on malformed input or when
 *         memory cannot be allocated
 */
int main(int argc, char** argv)
{
    enum { LINE_CAPACITY = 1000 };
    char buffer[16];

    // Handle switches
    int computeNumberOfSolutions = (argc == 2 && strcmp(argv[1], "-e") == 0);

    // Scan prime; a missing or unreadable value is reported like a non-prime
    long int prime = 0;
    if (fgets(buffer, sizeof buffer, stdin) == NULL
        || sscanf(buffer, "%ld", &prime) != 1
        || !isPrime(prime)) {
        printf("NOT A PRIME\n");
        return 0;
    }

    char *input = malloc(LINE_CAPACITY);
    if (input == NULL) {
        fprintf(stderr, "out of memory\n");
        return EXIT_FAILURE;
    }

    int status = 0;

    // Cycle until -1 -1 is entered or the input ends
    for (;;) {
        int nVariables = 0, nRows = 0, nEquations = 0;

        if (fgets(buffer, sizeof buffer, stdin) == NULL) {
            break; // end of input
        }
        if (sscanf(buffer, "%d %d", &nVariables, &nRows) != 2) {
            fprintf(stderr, "malformed system header\n");
            status = EXIT_FAILURE;
            break;
        }
        if (nVariables == -1 && nRows == -1) {
            break;
        }
        // The arrays below are sized from these values; they must be positive
        if (nVariables < 1 || nRows < 1) {
            fprintf(stderr, "invalid system dimensions\n");
            status = EXIT_FAILURE;
            break;
        }

        long int matrix[nRows][nVariables + 1];
        initMatrix(nRows, nVariables, matrix);
        long int solution[nVariables];

        // Scan equations
        int readFailed = 0;
        for (int i = 0; i < nRows; i++) {
            if (!readSpacelessLine(input, LINE_CAPACITY)) {
                readFailed = 1;
                break;
            }
            parseEquation(input, matrix[i], nVariables, prime);
        }
        if (readFailed) {
            fprintf(stderr, "could not read equation\n");
            status = EXIT_FAILURE;
            break;
        }

        // Solve system
        long long int nSolutions = eliminateMatrix(nRows, nVariables, matrix, prime, &nEquations, computeNumberOfSolutions);
        if (nSolutions == 0) {
            printf("NO SOLUTION\n");
            continue;
        }

        solveReducedMatrix(nRows, nEquations, nVariables, matrix, solution, prime);
        if (nSolutions == 1) {
            printf("ONE SOLUTION\n");
        } else if (!computeNumberOfSolutions) {
            printf("MANY SOLUTIONS\n");
        } else if (nSolutions == -1) {
            printf("TOO MANY SOLUTIONS\n");
        } else {
            // nSolutions is a signed long long, so %lld is the matching format
            printf("%lld SOLUTIONS\n", nSolutions);
        }
        printSolution(solution, nVariables);
    }

    free(input);
    return status;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment