Skip to content

Instantly share code, notes, and snippets.

@mp911de
Created December 22, 2015 15:05
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mp911de/464c1e0e2d19dfc904a7 to your computer and use it in GitHub Desktop.
Save mp911de/464c1e0e2d19dfc904a7 to your computer and use it in GitHub Desktop.
Wiener Process with Java and JFreeChart
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.title.TextTitle;
import org.jfree.data.time.Month;
import org.jfree.data.time.TimeTableXYDataset;
import org.jfree.data.xy.TableXYDataset;
import org.jfree.ui.RectangleEdge;
import org.jfree.ui.VerticalAlignment;
import javax.swing.*;
import javax.swing.event.ChangeEvent;
import javax.swing.event.ChangeListener;
import java.awt.*;
import java.text.DecimalFormat;
import java.util.List;
/**
 * Swing demo that visualizes a Wiener process projection with JFreeChart: quantile bands, the paid-in value and one
 * randomly simulated path. The inputs can be changed interactively via the spinners below the chart.
 *
 * @author <a href="mailto:mpaluch@paluch.biz">Mark Paluch</a>
 */
public class Main {

    /**
     * Step of the "max value" spinner; automatically computed maxima are rounded up to a multiple of it.
     */
    private static final int MAX_VALUE_STEP = 1000;

    JFrame frame;
    InputModel inputModel = new InputModel();
    JFreeChart chart;
    JSpinner maxValue;

    public static void main(String[] args) {
        // Swing components must be created and accessed on the event dispatch thread.
        SwingUtilities.invokeLater(new Runnable() {
            public void run() {
                new Main().run();
            }
        });
    }

    private Main() {
        createGui();
        refresh();
    }

    /**
     * Refresh the data and update the axes.
     * <p>
     * With fixed scaling the range axis spans {@code [0, inputModel.max]}. Otherwise both axes auto-range. In that
     * case {@code inputModel.max} and the max spinner are set to the data maximum, rounded up to
     * {@link #MAX_VALUE_STEP}, so that switching to fixed scaling keeps all data visible.
     */
    private void refresh() {
        XYPlot xyPlot = (XYPlot) chart.getPlot();
        xyPlot.getDomainAxis().setAutoRange(!inputModel.fixedScaling);
        TimeTableXYDataset dataSet = createDataSet();
        if (inputModel.fixedScaling) {
            xyPlot.getRangeAxis().setLowerBound(0);
            xyPlot.getRangeAxis().setUpperBound(inputModel.max);
        } else {
            // Setting explicit bounds switches auto-ranging off, so it must be re-enabled when leaving fixed mode.
            xyPlot.getRangeAxis().setAutoRange(true);
            inputModel.max = findMaximum(dataSet);
            maxValue.setValue(roundUpToStep(inputModel.max, MAX_VALUE_STEP));
        }
        xyPlot.setDataset(dataSet);
    }

    /**
     * Returns the largest y value of all series and items, or 0 if the data set is empty or all values are
     * negative.
     */
    private static double findMaximum(TimeTableXYDataset dataSet) {
        double max = 0;
        for (int series = 0; series < dataSet.getSeriesCount(); series++) {
            for (int item = 0; item < dataSet.getItemCount(); item++) {
                Number y = dataSet.getY(series, item);
                if (y != null) {
                    max = Math.max(max, y.doubleValue());
                }
            }
        }
        return max;
    }

    /**
     * Rounds {@code value} up to the next multiple of {@code step}.
     * <p>
     * The result is at least {@code step}. It is capped near {@link Integer#MAX_VALUE} instead of overflowing.
     */
    private static int roundUpToStep(double value, int step) {
        double rounded = Math.ceil(value / step) * step;
        return (int) Math.max(step, Math.min(rounded, Integer.MAX_VALUE));
    }

    /**
     * Creates the JFreeChart dataset from the current inputs.
     *
     * @return a fresh data set, including a newly simulated random path.
     */
    private TimeTableXYDataset createDataSet() {
        // The order of the breaks must match toQuantile(int).
        double[] breaks = {0.025, 1 / 6d, 0.5, 1 - 1 / 6d, 0.975};
        List<WienerProcess.PointInTime> result = WienerProcess
                .getProjection(inputModel.mu, inputModel.sigma, inputModel.years,
                        inputModel.initialValue, inputModel.monthlyValue, breaks);
        return convertToTimeTableDataset(result);
    }

    /**
     * Convert the Wiener process data to a JFreeChart data set.
     * <p>
     * Each point contributes one item per quantile bound, one to the "Value" series and one to the "Simulation"
     * series, all keyed by the point's month.
     */
    private static TimeTableXYDataset convertToTimeTableDataset(List<WienerProcess.PointInTime> result) {
        TimeTableXYDataset collection = new TimeTableXYDataset();
        for (WienerProcess.PointInTime pointInTime : result) {
            Month month = new Month(pointInTime.localDate.getMonthValue(), pointInTime.localDate.getYear());
            double[] bounds = pointInTime.bounds;
            for (int j = 0; j < bounds.length; j++) {
                collection.add(month, bounds[j], toQuantile(j));
            }
            collection.add(month, pointInTime.value, "Value");
            collection.add(month, pointInTime.simulation, "Simulation");
        }
        return collection;
    }

    /**
     * Maps the index of a quantile break (see createDataSet()) to its series name.
     */
    private static String toQuantile(int j) {
        switch (j) {
            case 0:
                return "95th quantile lower";
            case 1:
                return "66th quantile lower";
            case 3:
                return "66th quantile upper";
            case 4:
                return "95th quantile upper";
            default:
                return "Median";
        }
    }

    /**
     * Grid constraints for a label (or check box) in the first column of the input panel.
     */
    private static GridBagConstraints labelConstraints(int row) {
        GridBagConstraints gbc = new GridBagConstraints();
        gbc.gridx = 0;
        gbc.gridy = row;
        gbc.weightx = 2.0;
        gbc.insets = new Insets(0, 5, 0, 0);
        gbc.anchor = GridBagConstraints.WEST;
        return gbc;
    }

    /**
     * Grid constraints for an input field in the second column of the input panel.
     */
    private static GridBagConstraints fieldConstraints(int row) {
        GridBagConstraints gbc = new GridBagConstraints();
        gbc.gridx = 1;
        gbc.gridy = row;
        gbc.insets = new Insets(0, 5, 0, 0);
        gbc.anchor = GridBagConstraints.WEST;
        gbc.fill = GridBagConstraints.HORIZONTAL;
        return gbc;
    }

    /**
     * Adds a label/spinner pair to the given row of {@code panel} and returns the (model-less) spinner.
     */
    private static JSpinner addSpinnerRow(JPanel panel, int row, String label) {
        panel.add(new JLabel(label), labelConstraints(row));
        JSpinner spinner = new JSpinner();
        panel.add(spinner, fieldConstraints(row));
        return spinner;
    }

    /**
     * Create a swing gui to visualize the data and bind the input fields to {@link #inputModel}.
     */
    private void createGui() {
        frame = new JFrame("Wiener Process Demo");
        // The default (HIDE_ON_CLOSE) would keep the JVM running after the window is closed.
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.setSize(new Dimension(640, 480));
        frame.setLayout(new GridBagLayout());

        // create the chart...
        chart = createChart(new TimeTableXYDataset());
        chart.getPlot().setBackgroundPaint(Color.decode("#ffffff"));
        chart.getXYPlot().getRangeAxis().setTickMarkPaint(Color.decode("#999999"));
        chart.getXYPlot().setRangeGridlinePaint(Color.decode("#999999"));
        final ChartPanel chartPanel = new ChartPanel(chart);
        chartPanel.setPreferredSize(new Dimension(500, 270));
        chartPanel.setEnforceFileExtensions(false);

        GridBagConstraints c = new GridBagConstraints();
        c.gridx = 0;
        c.gridy = 0;
        c.anchor = GridBagConstraints.NORTH;
        c.fill = GridBagConstraints.BOTH;
        c.weightx = 1;
        c.weighty = 5;
        frame.add(chartPanel, c);

        // input panel: one label/field pair per row
        final JPanel panel1 = new JPanel(new GridBagLayout());
        final JSpinner mu = addSpinnerRow(panel1, 0, "Mu");
        final JSpinner sigma = addSpinnerRow(panel1, 1, "Sigma");
        final JSpinner initialValue = addSpinnerRow(panel1, 2, "Initial value");
        final JSpinner monthlyValue = addSpinnerRow(panel1, 3, "Monthly value");
        final JSpinner duration = addSpinnerRow(panel1, 4, "Duration");
        final JCheckBox fixedScaleTrueFalse = new JCheckBox("Fixed scale");
        panel1.add(fixedScaleTrueFalse, labelConstraints(5));
        maxValue = new JSpinner();
        panel1.add(maxValue, fieldConstraints(5));

        c = new GridBagConstraints();
        c.anchor = GridBagConstraints.WEST;
        c.fill = GridBagConstraints.HORIZONTAL;
        c.weightx = 1;
        c.weighty = 1;
        c.gridx = 0;
        c.gridy = 1;
        frame.add(panel1, c);

        // a bit of data binding
        mu.setModel(new SpinnerNumberModel(inputModel.mu, 0, 0.99, 0.01));
        mu.addChangeListener(new ChangeListener() {
            public void stateChanged(ChangeEvent e) {
                inputModel.mu = ((Number) mu.getModel().getValue()).doubleValue();
                refresh();
            }
        });
        sigma.setModel(new SpinnerNumberModel(inputModel.sigma, 0, 0.99, 0.01));
        sigma.addChangeListener(new ChangeListener() {
            public void stateChanged(ChangeEvent e) {
                inputModel.sigma = ((Number) sigma.getModel().getValue()).doubleValue();
                refresh();
            }
        });
        initialValue.setModel(new SpinnerNumberModel(inputModel.initialValue, 0, 100000, 1000));
        initialValue.addChangeListener(new ChangeListener() {
            public void stateChanged(ChangeEvent e) {
                inputModel.initialValue = ((Number) initialValue.getModel().getValue()).intValue();
                refresh();
            }
        });
        monthlyValue.setModel(new SpinnerNumberModel(inputModel.monthlyValue, 0, 100000, 50));
        monthlyValue.addChangeListener(new ChangeListener() {
            public void stateChanged(ChangeEvent e) {
                inputModel.monthlyValue = ((Number) monthlyValue.getModel().getValue()).intValue();
                refresh();
            }
        });
        duration.setModel(new SpinnerNumberModel(inputModel.years, 0, 25, 1));
        duration.addChangeListener(new ChangeListener() {
            public void stateChanged(ChangeEvent e) {
                inputModel.years = ((Number) duration.getModel().getValue()).intValue();
                refresh();
            }
        });
        fixedScaleTrueFalse.addChangeListener(new ChangeListener() {
            public void stateChanged(ChangeEvent e) {
                boolean selected = fixedScaleTrueFalse.getModel().isSelected();
                // Button change events also fire for rollover/armed/pressed; only react to real toggles.
                if (selected == inputModel.fixedScaling) {
                    return;
                }
                inputModel.fixedScaling = selected;
                maxValue.setEnabled(selected);
                refresh();
            }
        });
        maxValue.setEnabled(inputModel.fixedScaling);
        // Integer-typed model: refresh() stores Integer values; a Double-typed model would make the spinner's
        // bound checks fail with a ClassCastException on the next increment.
        maxValue.setModel(new SpinnerNumberModel((int) inputModel.max, 0, Integer.MAX_VALUE, MAX_VALUE_STEP));
        maxValue.addChangeListener(new ChangeListener() {
            public void stateChanged(ChangeEvent e) {
                inputModel.max = ((Number) maxValue.getModel().getValue()).intValue();
                // In auto-scaling mode refresh() itself updates this spinner; refreshing again from here would
                // re-enter refresh() and redraw with a different random path.
                if (inputModel.fixedScaling) {
                    refresh();
                }
            }
        });
    }

    /**
     * Creates a chart.
     *
     * @param dataset the dataset.
     * @return The chart.
     */
    private JFreeChart createChart(final TableXYDataset dataset) {
        final JFreeChart chart = ChartFactory.createTimeSeriesChart(
                "Area Chart", // chart title
                "Time", // domain axis label
                "Value", // range axis label
                dataset, // data
                true, // include legend
                true, // tooltips
                false // urls
        );
        chart.setBackgroundPaint(Color.white);
        final TextTitle subtitle = new TextTitle("This chart demonstrates brownian movement/the wiener process.");
        subtitle.setFont(new Font("SansSerif", Font.PLAIN, 12));
        subtitle.setPosition(RectangleEdge.TOP);
        subtitle.setVerticalAlignment(VerticalAlignment.BOTTOM);
        chart.addSubtitle(subtitle);
        return chart;
    }

    /**
     * Shows the main window.
     */
    private void run() {
        frame.setVisible(true);
    }

    /**
     * The input values for the calculation, bound to the GUI inputs in createGui().
     */
    public static class InputModel {
        /** Annualized mu passed to the projection. */
        double mu = 0.05;
        /** Annualized sigma passed to the projection. */
        double sigma = 0.1;
        /** Projection duration in years. */
        int years = 10;
        /** Value at the start of the projection. */
        int initialValue = 10000;
        /** Value added every month. */
        int monthlyValue = 100;
        /** Whether the range axis uses the fixed bounds [0, max] instead of auto-ranging. */
        boolean fixedScaling;
        /** Upper bound of the range axis in fixed-scaling mode. */
        double max;
    }
}
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-v4_0_0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>biz.paluch.redis</groupId>
    <artifactId>wiener-process</artifactId>
    <version>0.1-SNAPSHOT</version>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <!-- The sources use java.time (Java 8). Setting the level explicitly also avoids the compiler plugin's
             historical default (1.5), which recent JDKs no longer accept. -->
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.jfree</groupId>
            <artifactId>jfreechart</artifactId>
            <version>1.0.19</version>
        </dependency>
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-math3</artifactId>
            <version>3.5</version>
        </dependency>
    </dependencies>
</project>
import org.apache.commons.math3.distribution.NormalDistribution;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static java.lang.Math.sqrt;
public class WienerProcess {

    /**
     * Runs the Wiener process for a given period and initial amount, with a monthly value added every month.
     * <p>
     * The code calculates three things:
     * <ul>
     * <li>the projection of the paid-in value;</li>
     * <li>a set of quantiles, obtained by scaling the value with {@code exp(X)}, where {@code X} is normally
     * distributed with mean {@code mu * t} and standard deviation {@code sigma * sqrt(t)};</li>
     * <li>one random path of geometric Brownian motion.</li>
     * </ul>
     * The result contains {@code years * 12 + 1} points. The first point (today) carries the initial value as
     * value, simulation and every bound.
     *
     * @param mu           mean value (annualized)
     * @param sigma        standard deviation (annualized); may be 0
     * @param years        projection duration in years
     * @param initialValue the initial value
     * @param monthlyValue the value that is added per month
     * @param breaks       quantile breaks, each in [0, 1]
     * @return one point per month, starting with the current date
     */
    public static List<PointInTime> getProjection(double mu, double sigma, int years, int initialValue,
            int monthlyValue, double[] breaks) {
        double periodizedMu = mu / 12;
        double periodizedSigma = sigma / sqrt(12);
        LocalDate projectionStart = LocalDate.now();
        int periods = years * 12;

        // The quantile of N(m, s) is m + s * z, where z is the standard normal quantile.
        // Computing z once avoids building a distribution per period. It also supports sigma == 0, which
        // NormalDistribution rejects (its standard deviation must be strictly positive).
        NormalDistribution standardNormal = new NormalDistribution(0, 1);
        double[] z = new double[breaks.length];
        for (int j = 0; j < breaks.length; j++) {
            z[j] = standardNormal.inverseCumulativeProbability(breaks[j]);
        }

        List<PointInTime> result = new ArrayList<PointInTime>(Math.max(periods, 0) + 1);
        double[] initialBounds = new double[breaks.length];
        Arrays.fill(initialBounds, initialValue);
        result.add(new PointInTime(projectionStart, initialValue, initialValue, initialBounds));

        // Calculate value and quantiles
        for (int i = 0; i < periods; i++) {
            int elapsedMonths = i + 1;
            double value = initialValue + ((double) monthlyValue * i);
            double mean = periodizedMu * elapsedMonths;
            double deviation = periodizedSigma * sqrt(elapsedMonths);
            double[] bounds = new double[breaks.length];
            for (int j = 0; j < breaks.length; j++) {
                bounds[j] = value * Math.exp(mean + deviation * z[j]);
            }
            result.add(new PointInTime(projectionStart.plusMonths(elapsedMonths), value, 0, bounds));
        }

        // Simulate a path. Every month the monthly value is added to the previous month's simulated value, then a
        // random log-return ~ N(periodizedMu, periodizedSigma) is applied. Starting at index 0 ensures the
        // initial value enters the path.
        for (int i = 0; i < periods; i++) {
            double randomReturn = periodizedMu + periodizedSigma * standardNormal.sample();
            double value = (monthlyValue + result.get(i).simulation) * Math.exp(randomReturn);
            result.get(i + 1).simulation = value;
        }
        return result;
    }

    /**
     * A single monthly point of the projection.
     *
     * @author <a href="mailto:mpaluch@paluch.biz">Mark Paluch</a>
     */
    public static class PointInTime {
        /** Date of this point. */
        LocalDate localDate;
        /** Paid-in value (initial value plus monthly contributions). */
        double value;
        /** Value of the simulated random path at this point. */
        double simulation;
        /** Projected value per quantile break, in the order of the breaks passed to getProjection. */
        double[] bounds;

        public PointInTime(LocalDate localDate, double value, double simulation, double[] bounds) {
            this.localDate = localDate;
            this.value = value;
            this.simulation = simulation;
            this.bounds = bounds;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment