Skip to content

Instantly share code, notes, and snippets.

@OndrejSlamecka
Last active December 15, 2015 13:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save OndrejSlamecka/5264667 to your computer and use it in GitHub Desktop.
Save OndrejSlamecka/5264667 to your computer and use it in GitHub Desktop.
Solves a system of linear equations over the field Z_p, where p is prime. If there are many solutions, finds the lexicographically smallest one (smallest x1, then x2, …). With the -e parameter, prints the number of solutions up to 10^12. A complete solution to http://cecko.eu/public/lineq_2013 [cs]
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* --- IO --- */
/**
 * @brief Reads one line from stdin into input, dropping every ' ' character
 *
 * The terminating newline is consumed but not stored. The buffer is never grown:
 * input is passed by value, so a realloc here could not hand the new block back
 * to the caller (who would keep using, and later free, a stale pointer). Instead,
 * characters that do not fit into size - 1 bytes are read and discarded, so the
 * next call always starts at the beginning of the next line.
 *
 * @param input buffer of at least size bytes; NUL-terminated whenever size > 0
 * @param size  capacity of input in bytes
 * @return 1 if the whole line was stored; 0 if it had to be truncated, if
 *         size is 0 / input is NULL, or if end of input was hit before any
 *         character (not even a newline) could be read
 */
int readSpacelessLine(char *input, size_t size)
{
    if (input == NULL || size == 0) {
        return 0;
    }

    size_t len = 0;
    int truncated = 0;
    int c = getchar();

    if (c == EOF) { // nothing left to read at all
        input[0] = '\0';
        return 0;
    }

    // Compare against EOF explicitly: EOF is non-zero, so a bare truthiness test
    // would keep "reading" forever once the stream ends.
    for (; c != EOF && c != '\n'; c = getchar()) {
        if (c == ' ') {
            continue;
        }
        if (len + 1 < size) { // keep one byte for the terminator
            input[len++] = (char) c;
        } else {
            truncated = 1; // drain the rest of the line
        }
    }
    input[len] = '\0';
    return !truncated;
}
/**
 * @brief Development helper: dumps an augmented matrix to stdout
 *
 * Each matrix row goes on its own output line; every entry is printed with
 * "%5ld " (right-aligned in 5 columns, then a space). Column nVariables is the
 * right-hand-side column.
 *
 * @param nEquations number of rows of A
 * @param nVariables number of coefficient columns (A has nVariables + 1 columns)
 * @param A          the augmented matrix to print
 */
void printMatrix(int nEquations, int nVariables, long int A[nEquations][nVariables + 1])
{
    const int width = nVariables + 1; // coefficients plus the constant column

    for (int row = 0; row < nEquations; row++) {
        const long int *cells = A[row];
        for (int col = 0; col < width; col++) {
            printf("%5ld ", cells[col]);
        }
        putchar('\n');
    }
}
/**
 * @brief Prints the solution vector as "x1 = ...", "x2 = ...", one per line
 * @param solution   values of the variables; solution[k] belongs to x(k+1)
 * @param nVariables number of entries to print
 */
void printSolution(long int solution[], int nVariables)
{
    // Names are 1-based (x1, x2, ...) while the array is 0-based.
    for (int k = 0; k < nVariables; k++) {
        printf("x%d = %ld\n", k + 1, solution[k]);
    }
}
/* --- Maths --- */
/**
 * @brief Trial-division primality test
 * @param n value to test; any value below 2 (including negatives) is not prime
 * @return 1 if n is prime, 0 otherwise
 */
int isPrime(long int n)
{
    // Without this guard, negative odd numbers and -1 fell through the loop and
    // were reported as prime.
    if (n < 2) {
        return 0;
    }
    if (n < 4) { // 2 and 3
        return 1;
    }
    if (n % 2 == 0) {
        return 0;
    }
    // d <= n / d is d * d <= n without the overflow that an int d * d would
    // hit for n beyond 2^31.
    for (long int d = 3; d <= n / d; d += 2) {
        if (n % d == 0) {
            return 0;
        }
    }
    return 1;
}
/**
 * @brief Remainder of n modulo modulus, always in [0, modulus)
 *
 * C's % truncates toward zero, so -1 % 7 == -1; this returns 6 instead.
 *
 * @param n        dividend, any sign
 * @param modulus  divisor; must be positive (0 would divide by zero)
 * @return the representative of n in [0, modulus)
 */
long int mod(long int n, long int modulus)
{
    // Fix the sign of the remainder directly rather than computing
    // (n % m + m) % m: that intermediate sum overflows once m > LONG_MAX / 2.
    // The result is long int because it can be as large as modulus - 1, which
    // need not fit in an int.
    long int r = n % modulus;
    return r < 0 ? r + modulus : r;
}
/**
 * @brief Modular multiplicative inverse via the extended Euclidean algorithm
 *
 * @param number  value to invert; any sign, may exceed modulus
 * @param modulus positive modulus (a prime in this program)
 * @return x in [0, modulus) with number * x == 1 (mod modulus) when
 *         gcd(number, modulus) == 1; 0 when number == 0 (mod modulus)
 */
long int modInverse(long int number, long int modulus)
{
    // Reduce into [0, modulus) first. Starting from a negative value, the
    // remainders go negative and the algorithm ends at gcd = -1, which yields
    // the negated inverse (e.g. -4 mod 7 produced 2 instead of 5).
    long int a = number % modulus;
    if (a < 0) {
        a += modulus;
    }
    long int b = modulus;

    // x is the Bezout coefficient of `number` for b, xNext the one for a.
    long int x = 0, xNext = 1;
    while (a != 0) {
        long int q = b / a;
        long int r = b % a;
        long int t = x - q * xNext;
        x = xNext;
        xNext = t;
        b = a;
        a = r;
    }

    // Now b == gcd and x * number == b (mod modulus); normalize x.
    x %= modulus;
    if (x < 0) {
        x += modulus;
    }
    return x;
}
/**
 * @brief Division b / a in Z_p: b multiplied by the modular inverse of a
 *
 * When a == 0 there is no inverse; modInverse then yields 0, so the
 * quotient is 0 as well.
 *
 * @param b dividend
 * @param a divisor
 * @param p modulus
 * @return b * a^-1 reduced into [0, p)
 */
long int moduloDivision(long int b, long int a, long int p)
{
    long int inverseOfA = modInverse(a, p);
    return mod(inverseOfA * b, p);
}
/**
 * @brief Sets every entry of a rows x (nVariables + 1) augmented matrix to 0
 * @param rows       number of rows to clear
 * @param nVariables number of coefficient columns (plus one constant column)
 * @param matrix     the matrix to clear in place
 */
void initMatrix(int rows, int nVariables, long int matrix[rows][nVariables + 1])
{
    const int width = nVariables + 1;
    int r = 0;

    while (r < rows) {
        long int *line = matrix[r];
        for (int c = 0; c < width; c++) {
            line[c] = 0;
        }
        r++;
    }
}
/**
 * @brief Row-reduces the augmented matrix A over Z_modulus and classifies the system
 *
 * Columns are processed from the last variable to the first. Pivot rows are
 * filled from the bottom row upward; each pivot is scaled to 1 and eliminated
 * from the not-yet-pivoted rows above it. All-zero rows therefore end up at
 * the top.
 *
 * NOTE(review): entries are presumably already reduced into [0, modulus) on
 * entry (the zero tests rely on it).
 *
 * @param rows       number of rows (equations as entered)
 * @param nVariables number of variables; A has nVariables + 1 columns
 * @param A          augmented matrix, modified in place
 * @param modulus    prime modulus
 * @param nEquations out: number of non-empty rows after elimination, i.e. the
 *                   rank (only meaningful when the return value is not 0)
 * @param computeNumberOfSolutions non-zero to compute the exact solution count
 * @return 0 if there is no solution; 1 if it is unique; otherwise 2 when
 *         computeNumberOfSolutions is 0, else modulus^(free variables), or
 *         -1 when that count exceeds 10^12
 */
long long int eliminateMatrix(int rows, int nVariables, long int A[rows][nVariables + 1], long int modulus, int *nEquations, int computeNumberOfSolutions)
{
    const long long int SOLUTION_LIMIT = 1000000000000LL; // 10^12
    int cols = nVariables + 1, nPivots = 0;

    // ELIMINATION
    for (int x = nVariables - 1; x >= 0; x--) {
        // Every row already holds a pivot. With `>` instead of `>=`,
        // pivotPosition would become -1 and A[-1][x] would be read
        // (e.g. one equation in two unknowns).
        if (nPivots >= rows) {
            break;
        }
        int pivotPosition = rows - nPivots - 1;

        // If the pivot candidate is 0, look for a non-zero value in this column
        // among the unpivoted rows above and swap it in.
        if (A[pivotPosition][x] == 0) {
            int nonZeroValueRow = -1;
            for (int y = pivotPosition - 1; y >= 0; y--) {
                if (A[y][x] != 0) {
                    nonZeroValueRow = y;
                    break;
                }
            }
            if (nonZeroValueRow == -1) { // free column, no pivot here
                continue;
            }
            for (int j = 0; j < cols; j++) {
                long int swap = A[pivotPosition][j];
                A[pivotPosition][j] = A[nonZeroValueRow][j];
                A[nonZeroValueRow][j] = swap;
            }
        }

        // Scale the pivot row so the pivot becomes 1
        long int inverse = modInverse(A[pivotPosition][x], modulus);
        for (int j = 0; j < cols; j++) {
            A[pivotPosition][j] = mod(A[pivotPosition][j] * inverse, modulus);
        }

        // Clear column x in every unpivoted row above
        for (int i = pivotPosition - 1; i >= 0; i--) {
            long int multiplier = A[i][x];
            for (int j = 0; j < cols; j++) {
                A[i][j] = mod(A[i][j] - A[pivotPosition][j] * multiplier, modulus);
            }
        }
        nPivots++;
    }

    // CALCULATE NUMBER OF SOLUTIONS
    *nEquations = rows;
    for (int y = 0; y < rows; y++) {
        int emptyRow = 1;
        for (int x = 0; x < nVariables; x++) {
            if (A[y][x] != 0) {
                emptyRow = 0;
                break;
            }
        }
        if (emptyRow) {
            (*nEquations)--;
            if (A[y][nVariables] != 0) { // row reads 0 = c with c != 0
                return 0;
            }
        }
    }

    if (*nEquations == nVariables) {
        return 1;
    }
    if (!computeNumberOfSolutions) {
        return 2;
    }

    // Number of solutions = modulus ^ (number of free variables). Test before
    // multiplying so the product can never overflow long long.
    long long int product = 1;
    for (int i = nVariables - *nEquations; i > 0; i--) {
        if (product > SOLUTION_LIMIT / modulus) {
            return -1;
        }
        product *= modulus;
    }
    return product;
}
/**
 * @brief Back-substitutes the eliminated matrix into one concrete solution
 *
 * Only the last nEquations rows are read; they are processed from top to bottom.
 * In each row, the rightmost non-zero column not yet claimed by an earlier row
 * becomes that row's pivot variable. The pivot's value is computed from the
 * row's constant minus the contributions of the already-known variables.
 * Variables that never become a pivot keep the value 0.
 *
 * @param nRows      number of rows of A
 * @param nEquations number of non-empty rows (as reported by the elimination)
 * @param nVariables number of variables; A has nVariables + 1 columns
 * @param A          eliminated augmented matrix (not modified)
 * @param solution   out: at least nVariables elements
 * @param modulus    prime modulus
 */
void solveReducedMatrix(int nRows, int nEquations, int nVariables, long int A[nRows][nVariables + 1], long int solution[], long int modulus)
{
    if (nVariables <= 0) { // a zero-length VLA below would be undefined
        return;
    }

    int resolved[nVariables];
    // Exactly nVariables entries: solution (sized nVariables by the caller)
    // and resolved have no slot for the constant column.
    for (int i = 0; i < nVariables; i++) {
        solution[i] = 0;
        resolved[i] = 0;
    }

    for (int y = nRows - nEquations; y < nRows; y++) {
        int pivot = -1;
        long int value = A[y][nVariables];
        for (int x = nVariables - 1; x >= 0; x--) {
            if (A[y][x] != 0 && pivot == -1 && !resolved[x]) {
                pivot = x;
                resolved[x] = 1;
            } else {
                value = mod(value - A[y][x] * solution[x], modulus);
            }
        }
        if (pivot != -1) { // -1 when the row has no usable non-zero column
            solution[pivot] = moduloDivision(value, A[y][pivot], modulus);
        }
    }
}
/* --- Parsing --- */
/**
 * @brief Splits one term such as "3x2", "x7", "-x4" or "-12x10" into its
 *        coefficient and 1-based variable index
 *
 * A missing coefficient means 1; a lone sign means +1 / -1. input is not
 * modified.
 *
 * @param input term text (spaces already removed)
 * @param index out: digits after 'x' (0 if there is no 'x')
 * @param coef  out: coefficient written before 'x'
 */
void parseTerm(char *input, int *index, long int *coef)
{
    const char *x = strchr(input, 'x');

    if (x == NULL) { // constant only; callers normally filter these out
        *coef = strtol(input, NULL, 10);
        *index = 0;
        return;
    }

    size_t prefixLen = (size_t) (x - input);
    if (prefixLen == 0) {
        *coef = 1;
    } else if (prefixLen == 1 && (input[0] == '-' || input[0] == '+')) {
        // strtol("-") yields 0, which used to turn "-x3" into 0 * x3
        *coef = input[0] == '-' ? -1 : 1;
    } else {
        *coef = strtol(input, NULL, 10);
    }
    *index = (int) strtol(x + 1, NULL, 10);
}
/**
 * @brief Parses a '+'-separated sum of terms and adds it into row
 *
 * Terms containing 'x' are added to row[index - 1]; constants are added to
 * row[nVariables]. Values are accumulated with % modulus, so entries may be
 * negative representatives. Terms are split only at '+'; a '-' is understood
 * only as the sign of a term. NOTE(review): differences presumably have to be
 * written as "+-" in the input — confirm against the task's input format.
 *
 * @param input      polynomial text without spaces; NULL is treated as empty
 * @param row        accumulator with nVariables + 1 entries
 * @param nVariables number of variables
 * @param modulus    prime modulus
 */
void parsePolynomial(char *input, long int row[], int nVariables, int modulus)
{
    char chunk[50];
    size_t len = 0;
    int varIndex;
    long int coef;

    if (input == NULL) {
        return;
    }

    for (const char *p = input; ; p++) {
        char c = *p;
        if (c != '+' && c != '\0') {
            // Bound the copy: an over-long term used to overflow chunk[].
            // Excess characters are dropped (the term is then misread).
            if (len < sizeof chunk - 1) {
                chunk[len++] = c;
            }
            continue;
        }

        chunk[len] = '\0';
        len = 0;
        if (strchr(chunk, 'x') == NULL) {
            row[nVariables] = (row[nVariables] + strtol(chunk, NULL, 10) % modulus) % modulus;
        } else {
            parseTerm(chunk, &varIndex, &coef);
            // Ignore out-of-range variables instead of writing outside row[]
            if (varIndex >= 1 && varIndex <= nVariables) {
                row[varIndex - 1] = (row[varIndex - 1] + coef % modulus) % modulus;
            }
        }

        if (c == '\0') {
            break;
        }
    }
}
/**
 * @brief Turns "left = right" into "(left - right) x = constant" in pol1
 *
 * Coefficients become pol1[i] - pol2[i]. The constant moves to the other side
 * of the equation, so it becomes pol2[nVariables] - pol1[nVariables].
 * Everything is reduced into [0, modulus).
 *
 * @param pol1       left-hand side, overwritten with the result
 * @param pol2       right-hand side (read only)
 * @param nVariables number of coefficients before the constant entry
 * @param modulus    prime modulus
 */
void subtractPolynomials(long int pol1[], long int pol2[], int nVariables, int modulus)
{
    int i = 0;
    while (i < nVariables) {
        long int diff = (pol1[i] - pol2[i]) % modulus;
        pol1[i] = (diff + modulus) % modulus;
        i++;
    }

    long int negatedLeft = mod(-pol1[nVariables], modulus);
    pol1[nVariables] = mod(pol2[nVariables] + negatedLeft, modulus);
}
/**
 * @brief Parses "left=right" into one normalized row of the augmented matrix
 *
 * row receives the coefficients of left - right and, in row[nVariables], the
 * constant right - left. A missing '=' is read as "left = 0"; anything after
 * a second '=' is ignored. input is modified ('=' characters are overwritten).
 * parsePolynomial adds into row, so the caller is expected to pass a zeroed
 * row (main clears the matrix before parsing).
 *
 * @param input      equation text without spaces
 * @param row        output row with nVariables + 1 entries
 * @param nVariables number of variables
 * @param modulus    prime modulus
 */
void parseEquation(char *input, long int row[], int nVariables, int modulus)
{
    if (input == NULL) {
        return;
    }

    // Split with strchr rather than strtok: strtok returns NULL for an empty
    // right side (crash) and skips a leading '=' (so "=5" lost its left side).
    char *left = input;
    char *right = strchr(input, '=');
    if (right != NULL) {
        *right++ = '\0';
        char *extra = strchr(right, '='); // keep only the first right side
        if (extra != NULL) {
            *extra = '\0';
        }
    } else {
        right = input + strlen(input); // empty string: "left = 0"
    }

    long int aRight[nVariables + 1];
    for (int i = 0; i <= nVariables; i++) {
        aRight[i] = 0;
    }

    parsePolynomial(left, row, nVariables, modulus);
    parsePolynomial(right, aRight, nVariables, modulus);
    subtractPolynomials(row, aRight, nVariables, modulus);
}
/**
 * @brief Reads a prime p, then linear systems over Z_p, and prints each
 *        system's solution count and its lexicographically smallest solution
 *
 * Input: the first line holds p. Each system starts with a header line
 * "nVariables nEquations", followed by that many equation lines. A header of
 * "-1 -1", end of input, or an unreadable header ends the run.
 *
 * @param argc, argv "-e" as the only argument requests the exact solution count
 * @return 0 on success; EXIT_FAILURE on allocation failure, non-positive
 *         dimensions, or an equation line that could not be read whole
 */
int main(int argc, char **argv)
{
    enum { HEADER_SIZE = 64, LINE_SIZE = 1000 };
    char buffer[HEADER_SIZE];
    int status = EXIT_SUCCESS;

    // Handle switches
    int computeNumberOfSolutions = (argc == 2 && strcmp(argv[1], "-e") == 0);

    // Scan prime; an unreadable first line is rejected like a non-prime
    long int prime = 0;
    if (fgets(buffer, sizeof buffer, stdin) == NULL
            || sscanf(buffer, "%ld", &prime) != 1) {
        prime = 0;
    }
    if (!isPrime(prime)) {
        printf("NOT A PRIME\n");
        return 0;
    }

    char *input = malloc(LINE_SIZE);
    if (input == NULL) {
        return EXIT_FAILURE;
    }

    // Cycle until "-1 -1" is entered or input ends
    for (;;) {
        int nVariables, nRows;
        // Without these checks, EOF left stale sizes in place and the loop
        // never terminated.
        if (fgets(buffer, sizeof buffer, stdin) == NULL
                || sscanf(buffer, "%d %d", &nVariables, &nRows) != 2) {
            break;
        }
        if (nVariables == -1 && nRows == -1) {
            break;
        }
        if (nVariables < 1 || nRows < 1) { // array sizes must be positive
            status = EXIT_FAILURE;
            break;
        }

        // The matrix grows with nRows * nVariables, so keep it off the stack.
        // calloc zero-fills it (parseEquation accumulates into the rows) and
        // checks nRows * rowSize for overflow.
        long int (*matrix)[nVariables + 1] = calloc((size_t) nRows, sizeof *matrix);
        if (matrix == NULL) {
            status = EXIT_FAILURE;
            break;
        }
        long int solution[nVariables];

        // Scan equations; stop on a line that could not be read completely
        int readOk = 1;
        for (int i = 0; i < nRows && readOk; i++) {
            readOk = readSpacelessLine(input, LINE_SIZE);
            if (readOk) {
                parseEquation(input, matrix[i], nVariables, prime);
            }
        }
        if (!readOk) {
            free(matrix);
            status = EXIT_FAILURE;
            break;
        }

        // Solve system
        int nEquations = 0;
        long long int nSolutions = eliminateMatrix(nRows, nVariables, matrix, prime, &nEquations, computeNumberOfSolutions);
        if (nSolutions == 0) {
            printf("NO SOLUTION\n");
        } else {
            solveReducedMatrix(nRows, nEquations, nVariables, matrix, solution, prime);
            if (nSolutions == 1) {
                printf("ONE SOLUTION\n");
            } else if (!computeNumberOfSolutions) {
                printf("MANY SOLUTIONS\n");
            } else if (nSolutions == -1) {
                printf("TOO MANY SOLUTIONS\n");
            } else {
                printf("%lld SOLUTIONS\n", nSolutions); // signed long long: %lld
            }
            printSolution(solution, nVariables);
        }
        free(matrix);
    }

    free(input);
    return status;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment