Skip to content

Instantly share code, notes, and snippets.

@RobGeada
Created July 7, 2022 13:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save RobGeada/07447d1074e37690a4c21233d42a3233 to your computer and use it in GitHub Desktop.
Save RobGeada/07447d1074e37690a4c21233d42a3233 to your computer and use it in GitHub Desktop.
import drools_integrators.DroolsWrapper;
import drools_integrators.Explainers;
import org.apache.commons.math3.analysis.function.Exp;
import org.junit.Test;
import org.kie.api.KieServices;
import org.kie.api.runtime.KieContainer;
import org.kie.trustyai.explainability.model.Feature;
import org.kie.trustyai.explainability.model.FeatureFactory;
import org.kie.trustyai.explainability.model.Output;
import org.kie.trustyai.explainability.model.PredictionInput;
import org.kie.trustyai.explainability.model.Type;
import org.kie.trustyai.explainability.model.Value;
import org.kie.trustyai.explainability.model.domain.FeatureDomain;
import org.kie.trustyai.explainability.model.domain.NumericalFeatureDomain;
import org.kie.trustyai.explainability.model.domain.ObjectFeatureDomain;
import rulebases.buspass.Person;
import rulebases.cashflow.Account;
import rulebases.cashflow.AccountPeriod;
import rulebases.cashflow.CashFlow;
import rulebases.cashflow.CashFlowType;
import rulebases.cost.City;
import rulebases.cost.CostCalculationRequest;
import rulebases.cost.Order;
import rulebases.cost.OrderLine;
import rulebases.cost.Product;
import rulebases.cost.Step;
import rulebases.cost.Trip;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static rulebases.cashflow.CashFlowMain.date;
/**
 * Walkthrough that explains the Drools "cost" rulebase with TrustyAI.
 *
 * <p>A {@link CostCalculationRequest} (trip + order) is wrapped in a {@link DroolsWrapper}. The
 * wrapper is configured to choose which fields count as model features and outputs. SHAP and two
 * counterfactual searches are then run against it.
 */
public class TrustyAIDroolsBlogcode {
    // Classpath KIE container handed to DroolsWrapper together with the "CostRulesKS" name.
    KieServices ks = KieServices.Factory.get();
    KieContainer kieContainer = ks.getKieClasspathContainer();

    /**
     * Runs the full walkthrough in order: build the input, configure features, configure outputs,
     * run SHAP, then run two counterfactual searches.
     *
     * @throws ExecutionException   declared so checked exceptions from the DroolsWrapper/Explainers
     *                              calls can propagate
     * @throws InterruptedException declared so checked exceptions from the DroolsWrapper/Explainers
     *                              calls can propagate
     * @throws TimeoutException     declared so checked exceptions from the DroolsWrapper/Explainers
     *                              calls can propagate
     */
    public void main() throws ExecutionException, InterruptedException, TimeoutException {
        Supplier<List<Object>> objectSupplier = costRequestSupplier();

        // NOTE(review): "CostRulesKS" presumably names the KIE session in kmodule.xml, and the
        // meaning of "P1" is defined by DroolsWrapper; neither is visible here, so confirm both.
        DroolsWrapper droolsWrapper = new DroolsWrapper(kieContainer, "CostRulesKS", objectSupplier, "P1");

        configureFeatures(droolsWrapper, objectSupplier);
        configureOutputs(droolsWrapper);

        Explainers.runSHAP(droolsWrapper, buildShapBackground(droolsWrapper));

        runCounterfactuals(droolsWrapper);
    }

    /**
     * Returns a supplier that builds a brand-new request on every call.
     *
     * <p>The request holds a three-step trip (Shanghai, Rotterdam, Tournai, Lille) and a
     * four-line order.
     */
    private static Supplier<List<Object>> costRequestSupplier() {
        return () -> {
            City shanghai = new City(City.ShangaiCityName);
            City rotterdam = new City(City.RotterdamCityName);
            City tournai = new City(City.TournaiCityName);
            City lille = new City(City.LilleCityName);

            Trip trip = new Trip("trip1");
            trip.getSteps().add(new Step(shanghai, rotterdam, 22000, Step.Ship_TransportType));
            trip.getSteps().add(new Step(rotterdam, tournai, 300, Step.train_TransportType));
            trip.getSteps().add(new Step(tournai, lille, 20, Step.truck_TransportType));

            Order order = new Order("toExplain");
            Product drill = new Product("Drill", 0.2, 0.4, 0.3, 2, Product.transportType_pallet);
            Product sand = new Product("Sand", 0.0, 0.0, 0.0, 0.0, Product.transportType_bulkt);
            Product gravel = new Product("Gravel", 0.0, 0.0, 0.0, 0.0, Product.transportType_bulkt);
            Product furniture = new Product("Furniture", 0.0, 0.0, 0.0, 0.0, Product.transportType_individual);
            // NOTE(review): int vs double literals are kept as written. They presumably select
            // different OrderLine constructors (item count vs weight), so confirm before changing.
            order.getOrderLines().add(new OrderLine(1000, drill));
            order.getOrderLines().add(new OrderLine(35000.0, sand));
            order.getOrderLines().add(new OrderLine(14000.0, gravel));
            order.getOrderLines().add(new OrderLine(500, furniture));

            CostCalculationRequest request = new CostCalculationRequest();
            request.setTrip(trip);
            request.setOrder(order);
            return List.of(request);
        };
    }

    /**
     * Sets up feature extraction on the wrapper.
     *
     * <p>Extraction is restricted to order-line weight/numberItems and trip-step transportType.
     * Transport types are marked as categorical. Every extracted feature then gets a search domain,
     * and the resulting candidates are displayed.
     */
    private static void configureFeatures(DroolsWrapper droolsWrapper, Supplier<List<Object>> objectSupplier)
            throws ExecutionException, InterruptedException, TimeoutException {
        droolsWrapper.setFeatureExtractorFilters(
                List.of(
                        "(orderLines\\[\\d+\\].weight)",
                        "(orderLines\\[\\d+\\].numberItems)",
                        "(trip.steps\\[\\d+\\].transportType)"));

        HashMap<String, Type> featureTypeOverrides = new HashMap<>();
        featureTypeOverrides.put("trip.steps\\[\\d+\\].transportType", Type.CATEGORICAL);
        droolsWrapper.setFeatureTypeOverrides(featureTypeOverrides);

        for (Feature f : droolsWrapper.featureExtractor(objectSupplier.get()).keySet()) {
            if (f.getName().contains("transportType")) {
                // Transport type may be any of truck, train or ship.
                FeatureDomain<Object> transportDomain = ObjectFeatureDomain.create(
                        List.of(Step.truck_TransportType, Step.train_TransportType, Step.Ship_TransportType));
                droolsWrapper.addFeatureDomain(f.getName(), transportDomain);
            } else {
                // Numeric features range from 0 up to their original value.
                // A feature whose original value is 0 therefore gets a zero-width domain.
                // NOTE(review): raw type kept. addFeatureDomain's signature is not visible here,
                // and a raw FeatureDomain is assignable to any parameterization it may declare.
                FeatureDomain numericDomain = NumericalFeatureDomain.create(
                        0., ((Number) f.getValue().getUnderlyingObject()).doubleValue());
                droolsWrapper.addFeatureDomain(f.getName(), numericDomain);
            }
        }
        droolsWrapper.displayFeatureCandidates();
    }

    /**
     * Narrows output candidates and then generates them.
     *
     * <p>Listed object types and field names are excluded. Only consequences of the
     * "CalculateTotal" rule are kept.
     */
    private static void configureOutputs(DroolsWrapper droolsWrapper)
            throws ExecutionException, InterruptedException, TimeoutException {
        droolsWrapper.setExcludedOutputObjects(List.of(
                "pallets", "LeftToDistribute", "cost.Product", "cost.OrderLine", "java.lang.Double",
                "costElements", "Pallet", "City", "Step", "org.drools.core.reteoo.InitialFactImpl",
                "java.util.ArrayList"));
        droolsWrapper.setExcludedOutputFields(List.of(
                "pallets", "order", "trip", "step", "distance", "transportType", "city", "Step"));
        droolsWrapper.setIncludedOutputRules(List.of("CalculateTotal"));
        droolsWrapper.generateOutputCandidates(true);
    }

    /**
     * Builds a single-row SHAP background.
     *
     * <p>Every transportType feature is set to truck and every other feature is set to 0. Each
     * feature is copied from the wrapper's sample prediction input.
     */
    private static List<PredictionInput> buildShapBackground(DroolsWrapper droolsWrapper)
            throws ExecutionException, InterruptedException, TimeoutException {
        List<Feature> sampleFeatures = droolsWrapper.getSamplePredictionInput().getFeatures();
        List<Feature> backgroundFeatures = new ArrayList<>(sampleFeatures.size());
        for (Feature f : sampleFeatures) {
            Value baseline = f.getName().contains("transportType")
                    ? new Value(Step.truck_TransportType)
                    : new Value(0.);
            backgroundFeatures.add(FeatureFactory.copyOf(f, baseline));
        }
        return List.of(new PredictionInput(backgroundFeatures));
    }

    /**
     * Runs two counterfactual searches.
     *
     * <p>The first targets a total cost of 2,000,000. The second targets numPallets = 540 together
     * with totalTaxCost = 200.
     */
    private static void runCounterfactuals(DroolsWrapper droolsWrapper)
            throws ExecutionException, InterruptedException, TimeoutException {
        List<Output> totalCostGoal = List.of(new Output(
                "rulebases.cost.CostCalculationRequest.totalCost",
                Type.NUMBER,
                new Value(2_000_000), // goal: total cost of 2 million
                0.0d));
        Explainers.runCounterfactualSearch(
                droolsWrapper,
                totalCostGoal,
                .01, // aim for within 1% of the goal
                60L  // allow 60 seconds of search time
        );

        // NOTE(review): 4 and 5 are positions in the candidate list built by
        // generateOutputCandidates. They are assumed to be numPallets and totalTaxCost; confirm,
        // because they silently point elsewhere if the candidate order changes.
        droolsWrapper.selectOutputIndicesFromCandidates(List.of(4, 5));
        List<Output> palletAndTaxGoal = List.of(
                new Output("rulebases.cost.CostCalculationRequest.numPallets", Type.NUMBER, new Value(540), 0.0d),
                new Output("rulebases.cost.CostCalculationRequest.totalTaxCost", Type.NUMBER, new Value(200), 0.0d));
        Explainers.runCounterfactualSearch(
                droolsWrapper,
                palletAndTaxGoal,
                .005, // aim for within 0.5% of the goals
                300L  // allow 5 minutes of search time
        );
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment