Skip to content

Instantly share code, notes, and snippets.

@lvijay
Last active August 9, 2019 07:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lvijay/f3c95f1944b895df1b9f30c682b95b2c to your computer and use it in GitHub Desktop.
Save lvijay/f3c95f1944b895df1b9f30c682b95b2c to your computer and use it in GitHub Desktop.
java currying

Java Currying

The associated .java files demonstrate several examples of currying in the Java programming language.

None of this code is recommended for use in a production system; it is intended purely for teaching purposes. The standard NO WARRANTY clauses associated with Free Software also apply.

Running

The source has been tested on Java 8 and Java 11. The code has no external dependencies and should run with just the vanilla JDK.

Run with jdk11+

Java 11 introduced JEP 330, which lets the java launcher run single-file source programs directly, without an explicit compilation step.

$ java src/Currier0.java
hello1234567890
hello1234567890

$ java src/Currier1.java
hello1234567890
7.0

$ java src/Currier2.java
hello:12345:67890

$ java src/Currier3.java
two
7
3.1415926547261828
⊂∪ⓡ╓ү

Run with jdk8-jdk10

$ mkdir bin
$ javac -d bin src/Currier*.java

$ java -cp bin Currier0
hello1234567890
hello1234567890

$ java -cp bin Currier1
hello1234567890
7.0

$ java -cp bin Currier2
hello:12345:67890

$ java -cp bin Currier3
two
7
3.1415926547261828
⊂∪ⓡ╓ү

License

GNU Affero General Public License v3.0

import java.util.function.Function;
/**
 * Smallest currying demo: a fixed three-argument function next to a
 * hand-written curried equivalent built from nested lambdas.
 */
public class Currier0 {
    /**
     * Appends the string forms of {@code i} and {@code j} to {@code s}.
     *
     * @param s leading text ({@code null} is rendered as "null")
     * @param i first number ({@code null} is rendered as "null")
     * @param j second number ({@code null} is rendered as "null")
     * @return the concatenation, e.g. {@code concat("a", 1, 2)} is {@code "a12"}
     */
    public static String concat(String s, Integer i, Integer j) {
        return new StringBuilder()
                .append(s)
                .append(i)
                .append(j)
                .toString();
    }

    /**
     * Returns {@link #concat} in curried form: each application fixes one
     * argument, and the third application produces the concatenated string.
     *
     * @return a chain of three single-argument functions
     */
    public static
    Function<String,
             Function<Integer,
                      Function<Integer, String>>> curry() {
        return first -> second -> third -> concat(first, second, third);
    }

    /** Prints the same value computed directly and through the curried form. */
    public static void main(String[] args) {
        Function<String, Function<Integer, Function<Integer, String>>> curried = curry();
        System.out.println(concat("hello", 12345, 67890)); // prints hello1234567890
        System.out.println(curried.apply("hello").apply(12345).apply(67890)); // prints hello1234567890
    }
}
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.function.Function;
/**
 * Currying a fixed-arity (three-parameter) method obtained through reflection.
 */
public class Currier1 {
    /**
     * Concatenates a string with two {@code int}s; used as the static-method
     * target of {@link #curry1(Method)}.
     */
    public static String concat1(String s, int i, int j) {
        return s + i + j;
    }

    /**
     * Curries a three-parameter method into a chain of four single-argument
     * functions. The first application supplies the receiver ({@code null} for
     * a static method). Each of the next three supplies one argument, in
     * declaration order. The fourth application invokes {@code method}
     * reflectively and returns its (boxed) result.
     *
     * <p>Each step captures only its own argument, so a partially applied
     * function can be reused with different remaining arguments.
     *
     * @param method a method taking exactly three parameters; with any other
     *     arity, {@link Method#invoke} rejects the call when the last argument
     *     is applied
     * @return the curried function, raw because its nested type depends on
     *     {@code method}
     * @throws RuntimeException on the final application, wrapping an
     *     {@link IllegalAccessException} or the exception thrown by
     *     {@code method} itself
     */
    @SuppressWarnings("rawtypes")
    static Function curry1(Method method) {
        // Each lambda closes over the values applied so far; no shared mutable
        // state is needed because every argument is a separate captured variable.
        return (Function) (Object self) ->
            (Function) (Object i) ->
                (Function) (Object j) ->
                    (Function) (Object k) -> {
                        try {
                            return method.invoke(self, i, j, k);
                        } catch (IllegalAccessException e) {
                            throw new RuntimeException(e);
                        } catch (InvocationTargetException e) {
                            // Unwrap to surface the target method's own exception.
                            throw new RuntimeException(e.getCause());
                        }
                    };
    }

    /** Instance-method target for {@link #curry1(Method)}. */
    public static final class Distance {
        /**
         * Displacement after time {@code t}, with initial velocity {@code v}
         * and constant acceleration {@code u}: {@code v*t + u*t*t/2}.
         */
        public double distance(double t, double v, double u) {
            return v * t + 0.5 * u * t * t;
        }
    }

    public static void main(String[] args) throws Exception {
        Method concatM = Currier1.class.getMethod(
                "concat1", String.class, int.class, int.class);
        @SuppressWarnings("unchecked")
        Function<Currier1,
                Function<String,
                        Function<Integer,
                                Function<Integer, String>>>> curry1 = curry1(concatM);
        String curried1static = curry1
                .apply(null)
                .apply("hello")
                .apply(12345)
                .apply(67890);
        System.out.println(curried1static); // prints hello1234567890
        Method distM = Distance.class.getMethod(
                "distance", double.class, double.class, double.class);
        @SuppressWarnings("unchecked")
        Function<Distance,
                Function<Double,
                        Function<Double,
                                Function<Double, Double>>>> instanceCurry = curry1(distM);
        Distance cm = new Distance();
        Double distance = instanceCurry
                .apply(cm)
                .apply(1.0d) // t
                .apply(2.0d) // v
                .apply(10.0d); // u
        System.out.println(distance); // prints 7.0
    }
}
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.function.Function;
import java.util.function.Supplier;
/**
 * Currying a method of any arity obtained through reflection.
 */
public class Currier2 {
    /** Joins a string and two integers with {@code ':'} separators. */
    public static String concat2(String s, Integer i, Integer j) {
        return s + ":" + i + ":" + j;
    }

    /**
     * Curries {@code method} into a chain of single-argument functions. The
     * first application supplies the receiver ({@code null} for a static
     * method). Each following application supplies the next parameter, and the
     * application that supplies the last parameter invokes the method and
     * returns its (boxed) result. A zero-parameter method is invoked as soon
     * as the receiver is applied.
     *
     * <p>Every application copies the arguments bound so far, so a partially
     * applied function can be reused: {@code g.apply(1)} and {@code g.apply(3)}
     * yield independent continuations.
     *
     * @param method the method to curry
     * @return the curried function, raw because its nested type depends on
     *     {@code method}'s arity
     * @throws RuntimeException on the final application, wrapping an
     *     {@link IllegalAccessException} or the exception thrown by
     *     {@code method} itself
     */
    @SuppressWarnings("rawtypes")
    public static Function curry2(Method method) {
        int parameterCount = method.getParameterCount();
        Function f = self -> parameterCount == 0
                ? invoke(method, self, new Object[0])
                : bind(method, self, new Object[0], parameterCount);
        return f;
    }

    /**
     * Returns a function that appends one more argument to a private copy of
     * {@code bound}. It then either invokes {@code method} (once
     * {@code parameterCount} arguments are bound) or continues the chain.
     * {@code bound} itself is never mutated.
     */
    @SuppressWarnings("rawtypes")
    private static Function bind(Method method, Object self, Object[] bound, int parameterCount) {
        return arg -> {
            Object[] next = new Object[bound.length + 1];
            System.arraycopy(bound, 0, next, 0, bound.length);
            next[bound.length] = arg;
            return next.length == parameterCount
                    ? invoke(method, self, next)
                    : bind(method, self, next, parameterCount);
        };
    }

    /** Reflectively invokes {@code method}, rethrowing failures unchecked. */
    private static Object invoke(Method method, Object self, Object[] args) {
        try {
            return method.invoke(self, args);
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        } catch (InvocationTargetException e) {
            // Unwrap to surface the target method's own exception.
            throw new RuntimeException(e.getCause());
        }
    }

    @SuppressWarnings("unchecked")
    public static void main(String[] args) throws Exception {
        Method concatM = Currier2.class.getMethod(
                "concat2", String.class, Integer.class, Integer.class);
        Function<Currier2,
                Function<String,
                        Function<Integer,
                                Function<Integer, String>>>> curry2 = curry2(concatM);
        String concatted = curry2
                .apply(null)
                .apply("hello")
                .apply(12345)
                .apply(67890);
        System.out.println(concatted); // prints hello:12345:67890
    }
}
import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.nio.charset.StandardCharsets.UTF_8;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Random;
import java.util.function.Function;
/**
 * Type-inferred currying of reflective methods and constructors, built on an
 * immutable {@link Function} implementation.
 */
public final class Currier3 {
    /**
     * Curries {@code method}. The first application supplies the receiver
     * ({@code null} for a static method), and each following application
     * supplies the next parameter. The application that supplies the last
     * parameter invokes the method and returns its result. A zero-parameter
     * method is invoked as soon as the receiver is applied.
     *
     * <p>The type parameters are not checked against {@code method}; they
     * exist so callers can assign the result to a nested {@code Function}
     * type without casts.
     */
    public static <T, R> Function<T, R> curry3(Method method) {
        return new FunctionE<>(method);
    }

    /**
     * Curries {@code constructor} the same way as {@link #curry3(Method)}. The
     * value passed to the first application is ignored (pass {@code null}),
     * and the remaining applications supply the constructor arguments.
     */
    public static <T, R, K extends T> Function<T, R> curry3(Constructor<K> constructor) {
        return new FunctionE<>(constructor);
    }

    /**
     * One step of a curried call. Every {@link #apply} returns either the
     * final result or a new {@code FunctionE} holding a fresh copy of the
     * arguments bound so far. A partially applied instance can therefore be
     * reused with different remaining arguments.
     */
    private static final class FunctionE<T, R> implements Function<T, R> {
        private static final Object[] EMPTY = new Object[0];
        /** The method or constructor to call once all arguments are bound. */
        private final Executable executable;
        /** Number of parameters {@link #executable} declares. */
        private final int parameterCount;
        /** Receiver bound by the first application (unused for constructors). */
        private final T self;
        /** Applications that led to this instance; 0 for the initial function. */
        private final int invocationCount;
        /** Arguments bound so far; never mutated after construction. */
        private final Object[] env;

        public FunctionE(Method method) {
            this(method, null, 0, EMPTY);
        }

        public FunctionE(Constructor<? extends T> constructor) {
            this(constructor, null, 0, EMPTY);
        }

        private FunctionE(
                Executable executable,
                T self,
                int invocationCount,
                Object[] env) {
            this.executable = executable;
            this.parameterCount = executable.getParameterCount();
            this.invocationCount = invocationCount;
            this.self = self;
            this.env = env;
        }

        /**
         * Binds {@code t}: as the receiver on the first application, otherwise
         * as the next parameter. Returns the invocation result once every
         * parameter is bound, or else the next step.
         */
        @SuppressWarnings("unchecked")
        @Override
        public R apply(T t) {
            final T newSelf;
            final Object[] newEnv;
            if (invocationCount == 0) {
                newSelf = t;
                newEnv = env;
            } else {
                newSelf = self;
                // Copy rather than mutate, so sibling partial applications stay independent.
                newEnv = Arrays.copyOf(env, invocationCount);
                newEnv[invocationCount - 1] = t;
            }
            if (invocationCount == parameterCount) {
                return invoke(newSelf, newEnv);
            }
            return (R) new FunctionE<>(
                    executable,
                    newSelf,
                    1 + invocationCount,
                    newEnv);
        }

        /**
         * Calls {@link #executable} with {@code args}. A method receives
         * {@code target} as its receiver; a constructor ignores it.
         *
         * @throws RuntimeException wrapping {@link IllegalAccessException} or
         *     {@link InstantiationException}, or wrapping the exception thrown
         *     by the target itself
         */
        @SuppressWarnings("unchecked")
        private R invoke(T target, Object[] args) {
            try {
                if (executable instanceof Method) {
                    Method m = (Method) executable;
                    return (R) m.invoke(target, args);
                } else if (executable instanceof Constructor) {
                    Constructor<R> c = (Constructor<R>) executable;
                    return c.newInstance(args);
                } else {
                    throw new IllegalStateException("Cannot handle type " + executable.getClass());
                }
            } catch (IllegalAccessException | InstantiationException e) {
                // Neither exception carries the real cause (an InstantiationException
                // for an abstract class has none), so wrap the exception itself;
                // unwrapping would yield a RuntimeException with a null cause.
                throw new RuntimeException(e);
            } catch (InvocationTargetException e) {
                // Unwrap to surface the exception thrown by the target itself.
                throw new RuntimeException(e.getCause());
            }
        }
    }

    public static void main(String[] args) throws Exception {
        /*
         * instance method
         */
        Method listGet = List.class.getMethod("get", int.class);
        Function<List<?>,
                Function<Integer, String>> c3 = curry3(listGet);
        List<String> list = Arrays.asList("zero", "one", "two");
        String two = c3
                .apply(list)
                .apply(2);
        System.out.println(two); // prints two
        /*
         * static method
         */
        Method binarySearch = Collections.class.getMethod("binarySearch",
                List.class, Object.class, Comparator.class);
        Function<Void,
                Function<List<? extends String>,
                        Function<String,
                                Function<Comparator<? super String>, Integer>>>> c3static = curry3(binarySearch);
        Integer index = c3static
                .apply(null) // static method, no instance
                .apply(Arrays.asList("0;1;2;3;4;5;6;7".split(";")))
                .apply("7")
                .apply((s1, s2) -> s1.compareTo(s2));
        System.out.println(index); // prints 7
        /*
         * No arg method
         */
        Method nextDouble = Random.class.getMethod("nextDouble");
        Function<Random, Double> randomInvoke = curry3(nextDouble);
        Random random = new Random(1753592585);
        double randomDouble = randomInvoke.apply(random).doubleValue();
        System.out.println(10 * randomDouble); // prints 3.1415926547261828
        /*
         * constructor
         */
        Constructor<String> stringConstructor = String.class.getConstructor(byte[].class, Charset.class);
        Function<String,
                Function<byte[],
                        Function<Charset, String>>> c3cons = curry3(stringConstructor);
        // Encode with the same charset used to decode. ISO-8859-1 cannot represent
        // these characters (getBytes would substitute '?' for each of them), so
        // the round trip only reproduces the text when both sides use UTF-8.
        // Assumes this source is compiled as UTF-8.
        String byteArray = c3cons
                .apply(null)
                .apply("⊂∪ⓡ╓ү".getBytes(UTF_8))
                .apply(UTF_8);
        System.out.println(byteArray); // prints ⊂∪ⓡ╓ү
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment