%% Multidimensional arrays in Prolog
%% @GeoffChurch, created August 22, 2021 04:10
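:- use_module(library(clpfd)).
:- use_module(library(yall)).   % [X]>>Goal and Free/Lambda expressions
:- use_module(library(lists)).  % nth1/3, append/3 and friends
:- use_module(library(apply)).  % maplist/N and friends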
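%% at(Dimensions, Index, Array, Element): array access. Index is a list of
%% 1-based indices, one per dimension. Element is the entry of Array (a nested
%% list with the given Dimensions) at that position. If Array is unbound, the
%% lists along the accessed path are created.
at([], [], X, X).
at([D|Ds], [I|Is], A, X) :-
    length(A, D),
    nth1(I, A, AI),
    at(Ds, Is, AI, X).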
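%% A minimal plunit sketch. The unit and test names are illustrative and are
%% not part of the original gist. It assumes SWI-Prolog with the predicates
%% above loaded. The first test reads one element of a known 2x3 array. The
%% second shows that, for an unbound array, at/4 only creates the lists on the
%% accessed path, so the second row stays an unbound variable.
:- use_module(library(plunit)).
:- begin_tests(array_at).
test(read, true(X == 6)) :-
    at([2,3], [2,3], [[1,2,3],[4,5,6]], X).
test(build_path, true(A =@= [[0,_,_],_])) :-
    at([2,3], [1,1], A, 0).
:- end_tests(array_at).
%% Run with: ?- run_tests(array_at).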
%% `Ds` is the size of each dimension of an array. `Is` is the list of all
%% indices into an array with those dimensions, in row-major order.
indices(Ds, Is) :- findall(Idx, at(Ds, Idx, _, _), Is).
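%% An illustrative plunit sketch (hypothetical unit name; assumes the
%% predicates above are loaded). Indices come out in row-major order, and a
%% zero-dimensional (scalar) array has exactly one index: the empty list.
:- use_module(library(plunit)).
:- begin_tests(array_indices).
test(two_by_three, true(Is == [[1,1],[1,2],[1,3],[2,1],[2,2],[2,3]])) :-
    indices([2,3], Is).
test(scalar, true(Is == [[]])) :-
    indices([], Is).
:- end_tests(array_indices).
%% Run with: ?- run_tests(array_indices).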
%% Map P elementwise over every array in As. For each index, P is called with
%% the list of elements at that index, one per array. For example, to
%% initialize J and K to 3x3 matrices of all 1s and all 2s respectively:
%% map([3,3], [[1]]>>true, [J]), map([3,3], [[X,Y]]>>(Y#=X+X), [J, K]).
map(Ds, P, As) :-
    indices(Ds, Indices),
    maplist(
        {Ds, P, As}/[I]>>(maplist(at(Ds, I), As, Xs), call(P, Xs)),
        Indices).
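%% A plunit sketch of the example in the comment, shrunk to 2x2. The unit and
%% test names are hypothetical, and the predicates above are assumed loaded.
%% The first call fills J with 1s. The second uses clpfd to make each element
%% of K twice the matching element of J.
:- use_module(library(plunit)).
:- begin_tests(array_map).
test(fill, true(J-K == [[1,1],[1,1]]-[[2,2],[2,2]])) :-
    map([2,2], [[1]]>>true, [J]),
    map([2,2], [[X,Y]]>>(Y #= X + X), [J, K]).
:- end_tests(array_map).
%% Run with: ?- run_tests(array_map).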
%% Map P elementwise over As, with the current index as P's first parameter. For example, see identity/2.
mapi(Ds, P, As) :-
    indices(Ds, Indices),
    maplist(
        {Ds, P, As}/[I]>>(maplist(at(Ds, I), As, Xs), call(P, I, Xs)),
        Indices).
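%% A small plunit sketch with illustrative names, assuming the predicates above
%% are loaded. Because mapi/3 passes the index as P's first argument, storing
%% that index in every cell yields a grid of coordinates.
:- use_module(library(plunit)).
:- begin_tests(array_mapi).
test(index_grid, true(A == [[[1,1],[1,2]],[[2,1],[2,2]]])) :-
    mapi([2,2], [I, [X]]>>(X = I), [A]).
:- end_tests(array_mapi).
%% Run with: ?- run_tests(array_mapi).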
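%% eq(X, Y, B): B is 1 if X and Y are equal and 0 otherwise. dif/2 keeps the
%% second clause sound when X or Y is not yet bound.
eq(X, X, 1).
eq(X, Y, 0) :- dif(X, Y).
%% identity(D, I): I is the D x D identity matrix.
identity(D, I) :-
    mapi(
        [D,D],
        [[I1,I2], [X]]>>eq(I1, I2, X),
        [I]).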
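%% hadamard(Ds, A, B, C): C is the elementwise (Hadamard) product of A and B,
%% all three with dimensions Ds. For example, this makes B the 3x3 matrix with
%% 5 on the diagonal and 0 elsewhere:
%% identity(3, I), map([3,3], [[5]]>>true, [A]), hadamard([3,3], I, A, B).
hadamard(Ds, A, B, C) :-
    map(Ds, [[X,Y,Z]]>>(Z #= X * Y), [A,B,C]).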
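%% A plunit sketch with hypothetical names, assuming the predicates above are
%% loaded. The first test is the comment's example at 2x2. Because the product
%% is a clpfd constraint, the second test runs hadamard/4 "backwards" and
%% recovers an unknown factor from a known product.
:- use_module(library(plunit)).
:- begin_tests(array_hadamard).
test(scaled_identity, [nondet, true(B == [[5,0],[0,5]])]) :-
    identity(2, I),
    map([2,2], [[5]]>>true, [A]),
    hadamard([2,2], I, A, B).
test(solve_factor, [nondet, true(B == [3,4])]) :-
    hadamard([2], [2,3], B, [6,12]).
:- end_tests(array_hadamard).
%% Run with: ?- run_tests(array_hadamard).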
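%% kronecker(ADims, BDims, A, B, C): C is the Kronecker product of A and B.
%% ADims and BDims must have the same length. Dimension k of C has size
%% ADims[k] * BDims[k], and C index (AIk - 1) * BDk + BIk holds A at AIk times
%% B at BIk. C is built directly rather than by reshaping the nested
%% ADims ++ BDims array, because that reshape gives the wrong element order
%% for matrices. For example,
%% kronecker([2,2], [2,2], [[1,2],[3,4]], [[0,1],[1,0]], C) gives
%% C = [[0,1,0,2],[1,0,2,0],[0,3,0,4],[3,0,4,0]].
kronecker(ADims, BDims, A, B, C) :-
    length(ADims, LenADims),
    hadamard([LenADims], ADims, BDims, CDims),
    mapi(CDims, {A,B,ADims,BDims}/[CI, [X]]>>(
        % Split each C index CIk into an A index AIk and a B index BIk.
        maplist([CIk, BDk, AIk, BIk]>>(
            AIk is (CIk - 1) // BDk + 1,
            BIk is (CIk - 1) mod BDk + 1
        ), CI, BDims, AI, BI),
        at(ADims, AI, A, AX),
        at(BDims, BI, B, BX),
        X #= AX * BX
    ), [C]).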
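%% A plunit sketch (illustrative unit name; assumes the predicates above are
%% loaded) for the vector case. With A = [1,2] and B = [1,10,100], each element
%% of A multiplies all of B in turn.
:- use_module(library(plunit)).
:- begin_tests(array_kronecker).
test(vectors, [nondet, true(C == [1,10,100,2,20,200])]) :-
    kronecker([2], [3], [1,2], [1,10,100], C).
:- end_tests(array_kronecker).
%% Run with: ?- run_tests(array_kronecker).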
%% flatten(Dims, A, L): L lists the elements of the array A, which has
%% dimensions Dims, in row-major order. It works in either direction: given L,
%% it builds A.
flatten([], A, [A]).
flatten([D|Ds], A, L) :-
    length(A, D),
    maplist(flatten(Ds), A, Flats),
    append(Flats, L).   % concatenate the flattened sub-arrays in order
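%% An illustrative plunit sketch (hypothetical unit name; assumes the
%% predicates above are loaded). It checks both directions of flatten/3: array
%% to row-major list, and list back to array.
:- use_module(library(plunit)).
:- begin_tests(array_flatten).
test(forward, true(L == [1,2,3,4,5,6])) :-
    flatten([2,3], [[1,2,3],[4,5,6]], L).
test(backward, true(A == [[1,2,3],[4,5,6]])) :-
    flatten([2,3], A, [1,2,3,4,5,6]).
:- end_tests(array_flatten).
%% Run with: ?- run_tests(array_flatten).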
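%% reshape(InDs, In, OutDs, Out): Out has dimensions OutDs and holds the
%% elements of In (dimensions InDs) in the same row-major order. It fails if
%% the two shapes have different element counts.
reshape(InDs, In, OutDs, Out) :-
    flatten(InDs, In, InFlat),
    flatten(OutDs, Out, InFlat).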
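%% A plunit sketch with illustrative names, assuming the predicates above are
%% loaded. It reshapes a 2x3 matrix into a 3x2 one, and shows that reshape/4
%% fails when the element counts differ (6 vs. 8).
:- use_module(library(plunit)).
:- begin_tests(array_reshape).
test(two_by_three_to_three_by_two, true(Out == [[1,2],[3,4],[5,6]])) :-
    reshape([2,3], [[1,2,3],[4,5,6]], [3,2], Out).
test(size_mismatch, fail) :-
    reshape([2,3], [[1,2,3],[4,5,6]], [4,2], _).
:- end_tests(array_reshape).
%% Run with: ?- run_tests(array_reshape).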