Skip to content

Instantly share code, notes, and snippets.

@aadnk
Last active August 3, 2017 17:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aadnk/906c809671db5297daf309c292bea236 to your computer and use it in GitHub Desktop.
Save aadnk/906c809671db5297daf309c292bea236 to your computer and use it in GitHub Desktop.
Test of the pattern-matching pseudo code from the talk "Pattern Matching with Brian Goetz" (@briangoetz).
package com.comphenix.test;
import java.util.Objects;
public class PatternMatchingTest {

    /**
     * Runs the simplifier demo: first the proposed fix ({@code simplifyB}), then the original
     * algorithm ({@code simplifyA}).
     *
     * <p>The last call is expected to end in a {@link StackOverflowError}: the original algorithm
     * keeps rebuilding and re-simplifying {@code (1 + 1)} because no rewrite rule applies to it.
     */
    public static void main(String[] args) {
        AddNode zeroPlusOne = new AddNode(IntNode.ZERO, IntNode.ONE);
        AddNode onePlusOne = new AddNode(IntNode.ONE, IntNode.ONE);

        System.out.println("Testing our proposed fix: ");
        System.out.println("simplify(" + zeroPlusOne + ") => " + simplifyB(zeroPlusOne));
        System.out.println("simplify(" + onePlusOne + ") => " + simplifyB(onePlusOne));

        System.out.println("\nTesting original code: ");
        System.out.println("simplify(" + zeroPlusOne + ") => " + simplifyA(zeroPlusOne));
        // Intentionally left uncaught: demonstrates the unbounded recursion of the original code.
        System.out.println("simplify(" + onePlusOne + ") => " + simplifyA(onePlusOne));
    }

    /**
     * Simplifies an expression tree with a direct Java translation of the pseudo code from the
     * following Brian Goetz Java talk:
     * https://www.youtube.com/watch?v=n3_8YcYKScw&feature=youtu.be
     *
     * <p>Rewrite rules: {@code -(-x) => x}, {@code 0 + x => x}, {@code x + 0 => x},
     * {@code 1 * x => x}, {@code x * 1 => x}, {@code 0 * x => 0}, {@code x * 0 => 0}. Any other
     * node is rebuilt from its simplified children and then simplified again.
     *
     * <p><b>Warning:</b> whenever no rule applies to the rebuilt node (e.g. {@code (1 + 1)} or
     * {@code (-1)}), the same node is rebuilt and simplified again forever. This ends in a
     * {@link StackOverflowError}. See {@link #simplifyB(Node)} for the proposed fix.
     *
     * @param n the expression to simplify
     * @return the simplified expression
     * @throws IllegalArgumentException if {@code n} is {@code null} or an unsupported node type
     */
    static Node simplifyA(Node n) {
        if (n instanceof IntNode) {
            return n;
        } else if (n instanceof NegNode) {
            NegNode negNode = (NegNode) n;
            if (negNode.inner instanceof NegNode) {
                // -(-x) => x: unwrap BOTH negations, not just the outer one.
                return simplifyA(((NegNode) negNode.inner).inner);
            } else {
                return simplifyA(new NegNode(simplifyA(negNode.inner)));
            }
        } else if (n instanceof AddNode) {
            AddNode addNode = (AddNode) n;
            if (IntNode.ZERO.equals(addNode.left)) {
                return simplifyA(addNode.right);
            } else if (IntNode.ZERO.equals(addNode.right)) {
                return simplifyA(addNode.left);
            } else {
                return simplifyA(new AddNode(simplifyA(addNode.left), simplifyA(addNode.right)));
            }
        } else if (n instanceof MulNode) {
            MulNode mulNode = (MulNode) n;
            if (IntNode.ONE.equals(mulNode.left)) {
                return simplifyA(mulNode.right);
            } else if (IntNode.ONE.equals(mulNode.right)) {
                return simplifyA(mulNode.left);
            } else if (IntNode.ZERO.equals(mulNode.left)) {
                return IntNode.ZERO;
            } else if (IntNode.ZERO.equals(mulNode.right)) {
                return IntNode.ZERO;
            } else {
                return simplifyA(new MulNode(simplifyA(mulNode.left), simplifyA(mulNode.right)));
            }
        }
        // We can't handle this node yet (also reached for null).
        throw new IllegalArgumentException("Unknown node " + n);
    }

    /**
     * Simplifies an expression tree using the proposed fix. Unlike {@link #simplifyA(Node)}, it
     * stops instead of looping when rebuilding a node from its simplified children changes
     * nothing.
     *
     * @param n the expression to simplify
     * @return the simplified expression
     * @throws IllegalArgumentException if {@code n} is {@code null} or an unsupported node type
     */
    static Node simplifyB(Node n) {
        return simplifyB(n, null);
    }

    /**
     * Proposed fix: applies the same rules as {@link #simplifyA(Node)}, but each recursive call
     * also receives the node it was derived from.
     *
     * <p>If a rewrite produced a node equal to that source, simplifying the children changed
     * nothing. The source matched no rule, so the equal node cannot match one either. The node is
     * therefore returned as-is instead of being simplified again.
     *
     * @param n      the expression to simplify
     * @param source the node {@code n} was derived from, or {@code null} for a top-level call
     * @return the simplified expression
     * @throws IllegalArgumentException if {@code n} is {@code null} or an unsupported node type
     */
    private static Node simplifyB(Node n, Node source) {
        // Prevent infinite loops. Checking source != null keeps the "no source" sentinel from
        // matching a null node, so simplifyB(null) fails like simplifyA(null) instead of
        // silently returning null.
        if (source != null && source.equals(n)) {
            return n;
        }
        if (n instanceof IntNode) {
            return n;
        } else if (n instanceof NegNode) {
            NegNode negNode = (NegNode) n;
            if (negNode.inner instanceof NegNode) {
                // -(-x) => x: unwrap BOTH negations, not just the outer one.
                return simplifyB(((NegNode) negNode.inner).inner, n);
            } else {
                return simplifyB(new NegNode(simplifyB(negNode.inner)), n);
            }
        } else if (n instanceof AddNode) {
            AddNode addNode = (AddNode) n;
            if (IntNode.ZERO.equals(addNode.left)) {
                return simplifyB(addNode.right, n);
            } else if (IntNode.ZERO.equals(addNode.right)) {
                return simplifyB(addNode.left, n);
            } else {
                return simplifyB(new AddNode(simplifyB(addNode.left), simplifyB(addNode.right)), n);
            }
        } else if (n instanceof MulNode) {
            MulNode mulNode = (MulNode) n;
            if (IntNode.ONE.equals(mulNode.left)) {
                return simplifyB(mulNode.right, n);
            } else if (IntNode.ONE.equals(mulNode.right)) {
                return simplifyB(mulNode.left, n);
            } else if (IntNode.ZERO.equals(mulNode.left)) {
                return IntNode.ZERO;
            } else if (IntNode.ZERO.equals(mulNode.right)) {
                return IntNode.ZERO;
            } else {
                return simplifyB(new MulNode(simplifyB(mulNode.left), simplifyB(mulNode.right)), n);
            }
        }
        // We can't handle this node yet (also reached for null).
        throw new IllegalArgumentException("Unknown node " + n);
    }

    // *** Dependencies ***

    /** Base type of the immutable expression tree; equality between nodes is structural. */
    static abstract class Node {
    }

    /** Integer literal. */
    static class IntNode extends Node {
        static final IntNode ZERO = new IntNode(0);
        static final IntNode ONE = new IntNode(1);

        final int value;

        IntNode(int value) {
            this.value = value;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            return value == ((IntNode) o).value;
        }

        @Override
        public int hashCode() {
            return 31 * value;
        }

        @Override
        public String toString() {
            return String.valueOf(value);
        }
    }

    /** Unary negation {@code -inner}. */
    static class NegNode extends Node {
        final Node inner;

        NegNode(Node inner) {
            this.inner = inner;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            return Objects.equals(inner, ((NegNode) o).inner);
        }

        @Override
        public int hashCode() {
            return inner != null ? inner.hashCode() : 0;
        }

        @Override
        public String toString() {
            return "(-" + inner + ")";
        }
    }

    /** Binary addition {@code left + right}. */
    static class AddNode extends Node {
        final Node left;
        final Node right;

        AddNode(Node left, Node right) {
            this.left = left;
            this.right = right;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            AddNode other = (AddNode) o;
            return Objects.equals(left, other.left) && Objects.equals(right, other.right);
        }

        @Override
        public int hashCode() {
            return Objects.hash(left, right);
        }

        @Override
        public String toString() {
            return "(" + left + " + " + right + ")";
        }
    }

    /** Binary multiplication {@code left * right}. */
    static class MulNode extends Node {
        final Node left;
        final Node right;

        MulNode(Node left, Node right) {
            this.left = left;
            this.right = right;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            MulNode other = (MulNode) o;
            return Objects.equals(left, other.left) && Objects.equals(right, other.right);
        }

        @Override
        public int hashCode() {
            return Objects.hash(left, right);
        }

        @Override
        public String toString() {
            return "(" + left + " * " + right + ")";
        }
    }
}
// Original pseudo code from the talk. This is not compilable Java: it uses deconstruction
// patterns such as NegNode(NegNode(var n)) and sits outside the class body.
// It is adapted into Java as PatternMatchingTest.simplifyA above.
// NOTE(review): the catch-all NegNode/AddNode/MulNode cases rebuild the node from its simplified
// children and then simplify it again. When the rebuilt node matches no other case (e.g. 1 + 1),
// the same node is produced forever, so this version never terminates for such inputs.
Node simplify(Node n) {
return switch(n) {
case IntNode -> n;
// -(-x) => x. NOTE(review): the pattern variable `n` shadows the parameter `n`; the fixed
// version below renames it to `x`.
case NegNode(NegNode(var n)) -> simplify(n);
case NegNode(var n) -> simplify(new NegNode(simplify(n)));
case AddNode(IntNode(0), var right) -> simplify(right);
case AddNode(var left, IntNode(0)) -> simplify(left);
// Catch-all: rebuild from simplified children and re-simplify (see the termination note above).
case AddNode(var left, var right)
-> simplify(new AddNode(simplify(left), simplify(right)));
case MulNode(IntNode(1), var right) -> simplify(right);
case MulNode(var left, IntNode(1)) -> simplify(left);
case MulNode(IntNode(0), var right) -> new IntNode(0);
case MulNode(var left, IntNode(0)) -> new IntNode(0);
// Catch-all: same unbounded re-simplification as the AddNode case.
case MulNode(var left, var right)
-> simplify(new MulNode(simplify(left), simplify(right)));
// NOTE(review): no default case; a real switch expression would need Node to be sealed over
// exactly these four types for this to be exhaustive.
}
}
// Fixed pseudo code: the same rules as the original, but every recursive call passes the node
// it came from as `s`. Simplification stops as soon as a call receives a node equal to its
// source, i.e. when rebuilding from simplified children changed nothing.
// It is adapted into Java as PatternMatchingTest.simplifyB above. That version passes null
// rather than n as the source when simplifying children; since a child never equals its parent,
// this makes no difference for trees without null children.
// NOTE(review): presumably invoked as simplify(root, null) at the top level - confirm. With that
// call, Objects.equals(null, null) makes simplify(null, null) return null instead of failing.
Node simplify(Node n, Node s) {
// No progress since the source node: return instead of recursing forever.
if (Objects.equals(n, s)) return n;
return switch(n) {
case IntNode -> n;
// -(-x) => x
case NegNode(NegNode(var x)) -> simplify(x, n);
case NegNode(var x) -> simplify(new NegNode(simplify(x, n)), n);
case AddNode(IntNode(0), var right) -> simplify(right, n);
case AddNode(var left, IntNode(0)) -> simplify(left, n);
// Catch-all: terminates because the rebuilt node is compared against n on the next call.
case AddNode(var left, var right)
-> simplify(new AddNode(simplify(left, n), simplify(right, n)), n);
case MulNode(IntNode(1), var right) -> simplify(right, n);
case MulNode(var left, IntNode(1)) -> simplify(left, n);
case MulNode(IntNode(0), var right) -> new IntNode(0);
case MulNode(var left, IntNode(0)) -> new IntNode(0);
case MulNode(var left, var right)
-> simplify(new MulNode(simplify(left, n), simplify(right, n)), n);
default -> throw new IllegalArgumentException("Not recognized: " + n);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment