// zk-mnist circom
// From GitHub gist doutv/7d66968efc2c348405ad00349d34878c (last active July 10, 2023).
pragma circom 2.0.1;
include "";
include "";
// ArgMaxBits(n, bits): index of the largest of n inputs.
//   in[n] - values to compare. Every value must be below 2^bits, because the
//           comparison is done with GreaterThan(bits). NOTE(review): circomlib's
//           LessThan asserts bits <= 252 - confirm for the version in use.
//   out   - index of the first maximal input. Ties keep the earlier index, since
//           only a strictly greater value replaces the running maximum.
// Requires n >= 1.
// GreaterThan and Switcher come from the included circuit files (presumably
// circomlib's comparators.circom and switcher.circom - confirm the include paths).
// Assumes circomlib Switcher semantics: outL = sel ? R : L.
template ArgMaxBits (n, bits) {
    signal input in[n];
    signal output out;

    component gts[n];        // gts[i] tests in[i] > running max of in[0..i-1]
    component switchers[n];  // selects the new running maximum value
    component aswitchers[n]; // selects the new running arg max

    // maxs[i] / amaxs[i]: maximum value and its index over in[0..i].
    signal maxs[n];
    signal amaxs[n];

    maxs[0] <== in[0];
    amaxs[0] <== 0;
    // Start at 1: comparing in[0] with itself can never select it, so the
    // original first iteration only added unused constraints.
    for (var i = 1; i < n; i++) {
        gts[i] = GreaterThan(bits);
        gts[i].in[0] <== in[i];
        gts[i].in[1] <== maxs[i-1];

        switchers[i] = Switcher();
        switchers[i].sel <== gts[i].out;
        switchers[i].L <== maxs[i-1];
        switchers[i].R <== in[i];
        maxs[i] <== switchers[i].outL;

        aswitchers[i] = Switcher();
        aswitchers[i].sel <== gts[i].out;
        aswitchers[i].L <== amaxs[i-1];
        aswitchers[i].R <== i;
        amaxs[i] <== aswitchers[i].outL;
    }

    out <== amaxs[n-1];
}

// ArgMax(n): original interface. It compares with 30-bit comparators, so each
// input must be below 2^30 = 1073741824.
template ArgMax (n) {
    signal input in[n];
    signal output out;

    component am = ArgMaxBits(n, 30);
    for (var i = 0; i < n; i++) {
        am.in[i] <== in[i];
    }
    out <== am.out;
}
// image is a batch of b non-negative vectors, each of length n (n=84 in main), output from the prior NN layers running in the web frontend
// A is the final fully-connected layer's weights, a 10 x n matrix (10x84 when n=84)
// B is the final bias, a length-10 vector
// DigitReader(b, n): for each of b inputs, scores the 10 digit classes with a
// constant linear layer (A, B) and outputs the index of the highest score.
template DigitReader (b, n) {
signal input image[b][n]; // must be non-negative
//signal output digit;
// digits[idx] is the arg max class index for image[idx].
signal output digits[b];
var ndigits = 10;
// copy the values of A and B from the snarklayer.tsx file
// NOTE(review): the A2, A and B initializers below are cut off in this copy
// (nothing follows `=` / `[`). Restore the full constant arrays before compiling.
// A2 and Bb are not referenced in this template; only A and B are used.
var A2[ndigits][n] =
var Bb[ndigits] = [ -57068,33988,72676,115488,-84896,-15568,-94943,-115469,103303,-90709];
var A[ndigits][n] = [
var B[ndigits] = [
// s[idx][i][j] is the partial dot product of row A[i] with the first j entries
// of image[idx]; s[idx][i][n] is the full product.
signal s[b][ndigits][n+1];
//component am = ArgMax(ndigits);
// One ArgMax over the 10 class scores for each input.
component ams[b];
for(var idx=0; idx < b; idx++) {
ams[idx] = ArgMax(ndigits);
for(var i=0; i<ndigits; i++){
s[idx][i][0] <== 0;
// A is a compile-time `var`, so each accumulation step is a linear constraint.
for(var j=1; j<=n; j++){
s[idx][i][j] <== s[idx][i][j-1] + A[i][j-1]*image[idx][j-1];
//[i] <== s[idx][i][n]+B[i] + 10000000;
// The + 1000000000 offset presumably shifts negative scores into the
// non-negative range the comparator needs. NOTE(review): ArgMax uses 30-bit
// comparators, so every shifted score must lie in [0, 2^30) = [0, 1073741824).
// Confirm this holds for the trained weights.
ams[idx].in[i] <== s[idx][i][n]+B[i] + 1000000000;
// Debug output: prints the shifted scores of the first input only.
if (idx == 0) {
log(s[idx][i][n] + B[i] + 1000000000);
ams[idx].out ==> digits[idx];
// Appears to be left over from an earlier single-output version (cf. the
// commented-out `digit` signal); this copy's body holds only comments.
if (idx == 0) {
//am.out ==> digit;
//am.out ==> digit;
// Top-level circuit: a batch of b=16 inputs, each a length n=84 feature vector.
component main = DigitReader(16, 84);
// NOTE(review): the sample INPUT below is a truncated flat list of numbers, but
// DigitReader(16, 84) expects "image" as a 16x84 nested array. Regenerate it before use.
/* INPUT = {
"image": [34, 7, 3, 56, 34, 2, 11, 34, 7, 3, 56, 34, 2, 11, 34, 7, 3, 56, 34, 2, 11, 34, 7, 3, 56, 34, 2, 11, 34, 7, 3, 56, 34, 2, 11, 34, 7, 3, 56, 34, 2, 11, 34, 7, 3, 56, 34, 2, 11, 34,
0,0,0,0 ]
} */
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment