-- Simple linear regression and prediction in PL/pgSQL (PostgreSQL)
-- Source: GitHub gist ansarizafar/dd2658e65497a36bbeb11fc3dd1401ba
-- This code is based on my other Gist "Simple linear regression in JavaScript" (https://gist.github.com/patrickpissurno/ea0dc4039f075fbaf398619761bd9044)
-- There might be a more efficient way to do this in SQL (e.g. PostgreSQL's built-in regr_slope() and regr_intercept() aggregates)
-- Computes the weights of the least-squares line y = slope * x + intercept
-- through the points (x_array[i], y_array[i]).
--
-- Parameters:
--   x_array  values of the independent variable, one per observation
--   y_array  values of the dependent variable, same length as x_array
-- Returns:
--   ARRAY[slope, intercept], each rounded to two decimals; NULL when no line
--   can be fitted (NULL/empty input, a single point, or x values that are all
--   equal). A NULL element in either array yields ARRAY[NULL, NULL].
-- Raises:
--   invalid_parameter_value when the two arrays differ in length.
CREATE OR REPLACE FUNCTION linear_regression(x_array decimal(15,2)[], y_array decimal(15,2)[]) RETURNS decimal(15,2)[] AS $$
DECLARE
    -- Scale 2 is deliberate: assigning to these variables rounds the weights
    -- to the two decimals announced by the signature.
    slope decimal(15,2);
    intercept decimal(15,2);
    n integer;
    -- Unconstrained numeric accumulators: a decimal(15,2) running total
    -- overflows once sum(x * y) reaches 10^13 and re-rounds every partial sum.
    sum_x numeric := 0;
    sum_y numeric := 0;
    sum_xy numeric := 0;
    sum_xx numeric := 0;
    denominator numeric;
BEGIN
    n := ARRAY_LENGTH(y_array, 1);

    -- ARRAY_LENGTH is NULL for NULL and empty arrays, hence IS DISTINCT FROM
    IF ARRAY_LENGTH(x_array, 1) IS DISTINCT FROM n THEN
        RAISE EXCEPTION 'linear_regression: x_array has % element(s) but y_array has %',
            COALESCE(ARRAY_LENGTH(x_array, 1), 0), COALESCE(n, 0)
            USING ERRCODE = 'invalid_parameter_value';
    END IF;

    -- No observations, no line
    IF n IS NULL THEN
        RETURN NULL;
    END IF;

    FOR i IN 1..n LOOP
        sum_x := sum_x + x_array[i];
        sum_y := sum_y + y_array[i];
        sum_xy := sum_xy + (x_array[i] * y_array[i]);
        sum_xx := sum_xx + (x_array[i] * x_array[i]);
    END LOOP;

    -- n times the sum of squared deviations of x; zero means x never varies
    -- (this also covers n = 1), so the slope is undefined
    denominator := n * sum_xx - sum_x * sum_x;
    IF denominator = 0 THEN
        RETURN NULL;
    END IF;

    slope := (n * sum_xy - sum_x * sum_y) / denominator;
    -- Derived from the rounded slope so the returned pair stays self-consistent:
    -- the line still passes through the mean point (up to intercept rounding)
    intercept := (sum_y - slope * sum_x) / n;
    RETURN ARRAY[slope, intercept];
END
$$ LANGUAGE plpgsql;
-- Usage example: fit a line through (1,10), (2,25), (3,39), then evaluate it
WITH fitted_line AS (
    SELECT
        linear_regression(ARRAY[1, 2, 3], ARRAY[10, 25, 39]) AS weights
)
SELECT
    fitted_line.weights[1] * 4 + fitted_line.weights[2] AS prediction -- the constant 4 here is the x you want to predict
FROM fitted_line;
-- Alternative usage example: a small sales table to run the regression on
CREATE TABLE sales (
    sale_id    INTEGER        NOT NULL PRIMARY KEY,
    sale_date  DATE           NOT NULL,
    sale_value DECIMAL(15, 2) NOT NULL
);

-- Sample rows: one sale in each of three consecutive months
INSERT INTO sales (sale_id, sale_date, sale_value)
VALUES
    (1, '2019-03-11', 10),
    (2, '2019-04-04', 15),
    (3, '2019-05-11', 24);
-- Demo: predicts the total sales of one month from a straight line fitted
-- through the monthly totals of the sales table.
--
-- Parameters:
--   x  the month to predict, written as year + month number
--      (e.g. 201904 for April 2019)
-- Returns:
--   the predicted monthly total, or NULL when sales holds fewer than two
--   distinct months (a line cannot be fitted through less than two points).
CREATE OR REPLACE FUNCTION predict_sales(x DECIMAL(15,2)) RETURNS DECIMAL(15,2) AS $$
DECLARE
    weights DECIMAL(15,2)[];
    month_indexes DECIMAL(15,2)[];
    monthly_totals DECIMAL(15,2)[];
    target_index NUMERIC;
BEGIN
    -- One point per calendar month. The fit runs on a month index
    -- (year * 12 + month) rather than on the raw yyyymm number, because yyyymm
    -- jumps by 89 between December and January and would bend the line at
    -- every year boundary.
    WITH monthly_sales AS (
        SELECT
            EXTRACT(YEAR FROM sale_date) * 12 + EXTRACT(MONTH FROM sale_date) AS month_index,
            SUM(sale_value) AS total_value
        FROM sales
        GROUP BY EXTRACT(YEAR FROM sale_date), EXTRACT(MONTH FROM sale_date)
    )
    SELECT
        ARRAY_AGG(CAST(ms.month_index AS DECIMAL(15,2)) ORDER BY ms.month_index),
        ARRAY_AGG(CAST(ms.total_value AS DECIMAL(15,2)) ORDER BY ms.month_index)
    INTO month_indexes, monthly_totals
    FROM monthly_sales AS ms;

    -- ARRAY_AGG over zero rows is NULL, so treat that as zero months
    IF COALESCE(ARRAY_LENGTH(month_indexes, 1), 0) < 2 THEN
        RETURN NULL;
    END IF;

    -- Same yyyymm -> month index conversion for the requested month
    target_index := FLOOR(x / 100) * 12 + MOD(x, 100);

    weights := linear_regression(month_indexes, monthly_totals);
    RETURN target_index * weights[1] + weights[2];
END;
$$ LANGUAGE plpgsql;
-- 201906 stands for June, 2019. December, 2019 would be 201912
-- returns 30.33
SELECT
    predict_sales(201906) AS prediction;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment