Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ansarizafar/dd2658e65497a36bbeb11fc3dd1401ba to your computer and use it in GitHub Desktop.
Save ansarizafar/dd2658e65497a36bbeb11fc3dd1401ba to your computer and use it in GitHub Desktop.
Simple linear regression and prediction in PL/pgSQL (PostgreSQL)
-- This code is based on my other Gist "Simple linear regression in Javascript" (https://gist.github.com/patrickpissurno/ea0dc4039f075fbaf398619761bd9044)
-- There might be a more efficient way to do this in SQL
-- This function is responsible for computing the weights (slope a and intercept b) of the equation y = ax + b
CREATE OR REPLACE FUNCTION linear_regression(x_array decimal(15,2)[], y_array decimal(15,2)[]) RETURNS decimal(15,2)[] AS $$
-- Fits the least-squares line y = slope * x + intercept through the points
-- (x_array[i], y_array[i]) and returns array[slope, intercept].
--
-- Parameters:
--   x_array  x coordinates of the observations
--   y_array  y coordinates; must have as many elements as x_array
-- Returns:
--   a two-element array: [1] = slope, [2] = intercept, each rounded to two decimals
-- Raises:
--   null_value_not_allowed   when y_array is NULL or empty
--   invalid_parameter_value  when the two arrays differ in length
--   division_by_zero         when all x values are identical (this includes a single point)
-- NULL elements are not rejected: they propagate through the sums and yield NULL weights.
DECLARE
    -- The weights stay at two decimals because that is the precision callers have
    -- always received from this function.
    slope decimal(15,2);
    intercept decimal(15,2);
    n integer;
    denominator numeric;
    -- Unconstrained numeric: a fixed scale of 2 would round every product
    -- (e.g. 0.01 * 0.01 -> 0.00) and a fixed precision could overflow.
    sum_x numeric := 0;
    sum_y numeric := 0;
    sum_xy numeric := 0;
    sum_xx numeric := 0;
BEGIN
    n := ARRAY_LENGTH(y_array, 1);

    -- ARRAY_LENGTH yields NULL for both a NULL array and an empty one.
    IF n IS NULL THEN
        RAISE EXCEPTION 'linear_regression: y_array must not be NULL or empty'
            USING ERRCODE = 'null_value_not_allowed';
    END IF;

    IF ARRAY_LENGTH(x_array, 1) IS DISTINCT FROM n THEN
        RAISE EXCEPTION 'linear_regression: x_array has % elements but y_array has %',
            COALESCE(ARRAY_LENGTH(x_array, 1), 0), n
            USING ERRCODE = 'invalid_parameter_value';
    END IF;

    FOR i IN 1..n LOOP
        sum_x := sum_x + x_array[i];
        sum_y := sum_y + y_array[i];
        sum_xy := sum_xy + (x_array[i] * y_array[i]);
        sum_xx := sum_xx + (x_array[i] * x_array[i]);
    END LOOP;

    denominator := n * sum_xx - sum_x * sum_x;

    -- A zero denominator means the x values have no spread, so no unique line exists.
    IF denominator = 0 THEN
        RAISE EXCEPTION 'linear_regression: at least two distinct x values are required'
            USING ERRCODE = 'division_by_zero';
    END IF;

    slope := (n * sum_xy - sum_x * sum_y) / denominator;
    -- Uses the already-rounded slope on purpose: the returned line then still passes
    -- through the mean point (avg x, avg y), which keeps predictions accurate for large x.
    intercept := (sum_y - slope * sum_x) / n;

    RETURN ARRAY[slope, intercept];
END;
$$ LANGUAGE plpgsql;
-- Usage example
-- Fit the line through (1, 10), (2, 25), (3, 39), then evaluate it at x = 4.
SELECT (fit.weights[1] * 4 + fit.weights[2]) AS prediction -- the constant 4 here is the x you want to predict
FROM (
    SELECT linear_regression(array[1,2,3], array[10,25,39]) AS weights
) AS fit;
-- Alternative usage example
-- Demo table: one row per sale, identified by sale_id.
CREATE TABLE sales (
    sale_id INTEGER NOT NULL PRIMARY KEY,
    sale_date DATE NOT NULL,
    sale_value DECIMAL(15, 2) NOT NULL
);

-- Seed data: one sale in each of March, April and May 2019.
INSERT INTO sales (sale_id, sale_date, sale_value)
VALUES
    (1, '2019-03-11', 10),
    (2, '2019-04-04', 15),
    (3, '2019-05-11', 24);
-- This demo function receives a year + month number as a parameter (e.g. 201904 for April 2019) and returns the predicted sales total for that month
CREATE OR REPLACE FUNCTION predict_sales(x DECIMAL(15,2)) RETURNS DECIMAL(15,2) AS $$
-- Predicts the sales total of one month by fitting a straight line through the
-- monthly totals found in the sales table.
--
-- Parameters:
--   x  the month to predict, encoded as YYYYMM (e.g. 201906 for June 2019)
-- Returns:
--   the value of the fitted line at that month
-- Raises:
--   null_value_not_allowed  when the sales table is empty
--   Errors raised by linear_regression (e.g. when only one month has data) propagate.
DECLARE
    weights DECIMAL(15,2)[];
    past_data RECORD;
    target_month numeric;
BEGIN
    -- The regression runs on a continuous month counter (year * 12 + month).
    -- YYYYMM itself is not evenly spaced: December -> January jumps by 89 instead of 1,
    -- which would distort any fit or prediction that crosses a year boundary.
    WITH monthly_sales AS (
        SELECT
            EXTRACT(YEAR FROM sale_date) * 12 + EXTRACT(MONTH FROM sale_date) AS month_index,
            SUM(sale_value) AS month_total
        FROM sales
        GROUP BY EXTRACT(YEAR FROM sale_date), EXTRACT(MONTH FROM sale_date)
    )
    SELECT
        -- Both aggregates share the same ORDER BY so x_array[i] and y_array[i] stay paired.
        ARRAY_AGG(CAST(ms.month_index AS DECIMAL(15,2)) ORDER BY ms.month_index) AS x_array,
        ARRAY_AGG(CAST(ms.month_total AS DECIMAL(15,2)) ORDER BY ms.month_index) AS y_array
    INTO past_data
    FROM monthly_sales AS ms;

    -- ARRAY_AGG over zero rows yields NULL.
    IF past_data.x_array IS NULL THEN
        RAISE EXCEPTION 'predict_sales: the sales table has no rows to fit'
            USING ERRCODE = 'null_value_not_allowed';
    END IF;

    weights := linear_regression(past_data.x_array, past_data.y_array);

    -- Convert the YYYYMM argument to the same month counter:
    -- year * 100 + month  ->  year * 12 + month, i.e. subtract 88 * year.
    target_month := x - 88 * FLOOR(x / 100);

    RETURN target_month * weights[1] + weights[2];
END;
$$ LANGUAGE plpgsql;
-- 201906 stands for June, 2019. December, 2019 would be 201912
SELECT
    predict_sales(201906) AS prediction;
-- returns 30.33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment