Created July 7, 2022 13:22
Save RobGeada/07447d1074e37690a4c21233d42a3233 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import drools_integrators.DroolsWrapper; | |
import drools_integrators.Explainers; | |
import org.apache.commons.math3.analysis.function.Exp; | |
import org.junit.Test; | |
import org.kie.api.KieServices; | |
import org.kie.api.runtime.KieContainer; | |
import org.kie.trustyai.explainability.model.Feature; | |
import org.kie.trustyai.explainability.model.FeatureFactory; | |
import org.kie.trustyai.explainability.model.Output; | |
import org.kie.trustyai.explainability.model.PredictionInput; | |
import org.kie.trustyai.explainability.model.Type; | |
import org.kie.trustyai.explainability.model.Value; | |
import org.kie.trustyai.explainability.model.domain.FeatureDomain; | |
import org.kie.trustyai.explainability.model.domain.NumericalFeatureDomain; | |
import org.kie.trustyai.explainability.model.domain.ObjectFeatureDomain; | |
import rulebases.buspass.Person; | |
import rulebases.cashflow.Account; | |
import rulebases.cashflow.AccountPeriod; | |
import rulebases.cashflow.CashFlow; | |
import rulebases.cashflow.CashFlowType; | |
import rulebases.cost.City; | |
import rulebases.cost.CostCalculationRequest; | |
import rulebases.cost.Order; | |
import rulebases.cost.OrderLine; | |
import rulebases.cost.Product; | |
import rulebases.cost.Step; | |
import rulebases.cost.Trip; | |
import java.text.ParseException; | |
import java.util.ArrayList; | |
import java.util.HashMap; | |
import java.util.List; | |
import java.util.concurrent.ExecutionException; | |
import java.util.concurrent.TimeoutException; | |
import java.util.function.Supplier; | |
import java.util.stream.Collectors; | |
import java.util.stream.Stream; | |
import static rulebases.cashflow.CashFlowMain.date; | |
public class TrustyAIDroolsBlogcode { | |
KieServices ks = KieServices.Factory.get(); | |
KieContainer kieContainer = ks.getKieClasspathContainer(); | |
public void main() throws ExecutionException, InterruptedException, TimeoutException { | |
// SETUP ======================================================================== | |
// ============================================================================== | |
Supplier<List<Object>> objectSupplier = () -> { | |
// define Trip | |
City cityOfShangai = new City(City.ShangaiCityName); | |
City cityOfRotterdam = new City(City.RotterdamCityName); | |
City cityOfTournai = new City(City.TournaiCityName); | |
City cityOfLille = new City(City.LilleCityName); | |
Step step1 = new Step(cityOfShangai, cityOfRotterdam, 22000, Step.Ship_TransportType); | |
Step step2 = new Step(cityOfRotterdam, cityOfTournai, 300, Step.train_TransportType); | |
Step step3 = new Step(cityOfTournai, cityOfLille, 20, Step.truck_TransportType); | |
Trip trip = new Trip("trip1"); | |
trip.getSteps().add(step1); | |
trip.getSteps().add(step2); | |
trip.getSteps().add(step3); | |
// define Order | |
Order order = new Order("toExplain"); | |
Product drillProduct = new Product("Drill", 0.2, 0.4, 0.3, 2, Product.transportType_pallet); | |
Product screwDriverProduct = new Product("Screwdriver", 0.03, 0.02, 0.2, 0.2, Product.transportType_pallet); | |
Product sandProduct = new Product("Sand", 0.0, 0.0, 0.0, 0.0, Product.transportType_bulkt); | |
Product gravelProduct = new Product("Gravel", 0.0, 0.0, 0.0, 0.0, Product.transportType_bulkt); | |
Product furnitureProduct = new Product("Furniture", 0.0, 0.0, 0.0, 0.0, Product.transportType_individual); | |
order.getOrderLines().add(new OrderLine(1000, drillProduct)); | |
order.getOrderLines().add(new OrderLine(35000.0, sandProduct)); | |
order.getOrderLines().add(new OrderLine(14000.0, gravelProduct)); | |
order.getOrderLines().add(new OrderLine(500, furnitureProduct)); | |
// combine Trip and Order into CostCalculationRequest | |
CostCalculationRequest request = new CostCalculationRequest(); | |
request.setTrip(trip); | |
request.setOrder(order); | |
return List.of(request); | |
}; | |
// initialize the wrapper | |
DroolsWrapper droolsWrapper = new DroolsWrapper(kieContainer, "CostRulesKS", objectSupplier, "P1"); | |
// setup Feature extraction | |
//droolsWrapper.displayFeatureCandidates(); | |
// FEATURE SELECTION ============================================================= | |
// ============================================================================== | |
// apply filters | |
droolsWrapper.setFeatureExtractorFilters( | |
List.of( | |
"(orderLines\\[\\d+\\].weight)", | |
"(orderLines\\[\\d+\\].numberItems)", | |
"(trip.steps\\[\\d+\\].transportType)" | |
) | |
); | |
//droolsWrapper.displayFeatureCandidates(); | |
// set feature type specifications | |
HashMap<String, Type> featureTypeOverrides = new HashMap<>(); | |
featureTypeOverrides.put("trip.steps\\[\\d+\\].transportType", Type.CATEGORICAL); | |
droolsWrapper.setFeatureTypeOverrides(featureTypeOverrides); | |
// set feature domains | |
for (Feature f : droolsWrapper.featureExtractor(objectSupplier.get()).keySet()) { | |
if (f.getName().contains("transportType")) { | |
// transport type can be truck, train, ship | |
FeatureDomain<Object> fd = ObjectFeatureDomain.create(List.of(Step.truck_TransportType, Step.train_TransportType, Step.Ship_TransportType)); | |
droolsWrapper.addFeatureDomain(f.getName(), fd); | |
} else { | |
// let numeric features range from 0 to original value | |
FeatureDomain nfd = NumericalFeatureDomain.create(0., ((Number) f.getValue().getUnderlyingObject()).doubleValue()); | |
droolsWrapper.addFeatureDomain(f.getName(), nfd); | |
} | |
} | |
droolsWrapper.displayFeatureCandidates(); | |
// OUTPUT SELECTION ============================================================== | |
// ============================================================================== | |
// exclude the following objects | |
droolsWrapper.setExcludedOutputObjects(List.of( | |
"pallets", "LeftToDistribute", "cost.Product", "cost.OrderLine", "java.lang.Double", "costElements", "Pallet", "City", "Step", "org.drools.core.reteoo.InitialFactImpl", "java.util.ArrayList")); | |
// exclude the following field names | |
droolsWrapper.setExcludedOutputFields(List.of("pallets", "order", "trip", "step", "distance", "transportType", "city", "Step")); | |
// only look at consequences of the following rules | |
droolsWrapper.setIncludedOutputRules(List.of("CalculateTotal")); | |
droolsWrapper.generateOutputCandidates(true); | |
// SHAP ========================================================================== | |
// ============================================================================== | |
//create background | |
List<Feature> backgroundFeatures = new ArrayList<>(); | |
PredictionInput samplePI = droolsWrapper.getSamplePredictionInput(); | |
for (int j = 0; j < samplePI.getFeatures().size(); j++) { | |
Feature f = samplePI.getFeatures().get(j); | |
if (f.getName().contains("transportType")) { | |
backgroundFeatures.add(FeatureFactory.copyOf(f, new Value(Step.truck_TransportType))); | |
} else { | |
backgroundFeatures.add(FeatureFactory.copyOf(f, new Value(0.))); | |
} | |
} | |
List<PredictionInput> background = List.of(new PredictionInput(backgroundFeatures)); | |
// run SHAP | |
Explainers.runSHAP(droolsWrapper, background); | |
// COUNTERFACTUALS =============================================================== | |
// ============================================================================== | |
// generate goal | |
List<Output> goal = List.of(new Output( | |
"rulebases.cost.CostCalculationRequest.totalCost", | |
Type.NUMBER, | |
new Value(2_000_000), // this is where we set our goal to 2 million | |
0.0d) | |
); | |
Explainers.runCounterfactualSearch( | |
droolsWrapper, | |
goal, | |
.01, //want to get within 1% of goal | |
60L //allow 60 seconds of search time | |
); | |
// COUNTERFACTUALS (more complex) ================================ | |
droolsWrapper.selectOutputIndicesFromCandidates(List.of(4,5)); | |
goal = List.of( | |
new Output("rulebases.cost.CostCalculationRequest.numPallets", Type.NUMBER, new Value(540), 0.0d), | |
new Output("rulebases.cost.CostCalculationRequest.totalTaxCost", Type.NUMBER, new Value(200), 0.0d) | |
); | |
Explainers.runCounterfactualSearch(droolsWrapper, | |
goal, | |
.005, //aim for within 0.5% of goals | |
300L //allow for 5 minutes of search time | |
); | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.