Skip to content

Instantly share code, notes, and snippets.

@lissahyacinth
Created November 9, 2018 10:23
Show Gist options
  • Save lissahyacinth/510cb4dcd2aa60c5d8e8ad4c2b7db1e9 to your computer and use it in GitHub Desktop.
Save lissahyacinth/510cb4dcd2aa60c5d8e8ad4c2b7db1e9 to your computer and use it in GitHub Desktop.
PrestoML Example
-- Example: train a libsvm regressor (Presto ML plugin) to learn sin(x),
-- then score the model on a grid of integer inputs.
WITH training_data AS (
    -- 100 (x, sin(x)) samples for x = 0, 2.0202..., ..., 200 (step 200/99),
    -- stored as text and cast to the types learn_libsvm_regressor expects.
    -- input_label is the feature key; every row uses the single key 1.
    SELECT
        CAST(input_label AS BIGINT) AS feature_label,
        CAST(input AS DOUBLE) AS feature,
        CAST(output AS DOUBLE) AS label
    FROM (
        VALUES
            ('1', '0', '0'),
            ('1', '2.02020202020202', '0.900705446202955'),
            ('1', '4.04040404040404', '-0.782587502654202'),
            ('1', '6.06060606060606', '-0.220745974555063'),
            ('1', '8.08080808080808', '0.974384989475536'),
            ('1', '10.1010101010101', '-0.625858782585018'),
            ('1', '12.1212121212121', '-0.430600932498663'),
            ('1', '14.1414141414141', '0.999990980658534'),
            ('1', '16.1616161616162', '-0.438251862307188'),
            ('1', '18.1818181818182', '-0.619211190881117'),
            ('1', '20.2020202020202', '0.97626008855921'),
            ('1', '22.2222222222222', '-0.229022766032661'),
            ('1', '24.2424242424242', '-0.777271223469046'),
            ('1', '26.2626262626263', '0.904363131991268'),
            ('1', '28.2828282828283', '-0.00849429836849447'),
            ('1', '30.3030303030303', '-0.896982770547887'),
            ('1', '32.3232323232323', '0.787847314702699'),
            ('1', '34.3434343434343', '0.21245325528272'),
            ('1', '36.3636363636364', '-0.972439584221402'),
            ('1', '38.3838383838384', '0.63246121582001'),
            ('1', '40.4040404040404', '0.422918932935483'),
            ('1', '42.4242424242424', '-0.999918826902981'),
            ('1', '44.4444444444444', '0.445871170312764'),
            ('1', '46.4646464646465', '0.612518920361339'),
            ('1', '48.4848484848485', '-0.978064746175762'),
            ('1', '50.5050505050505', '0.237283032508533'),
            ('1', '52.5252525252525', '0.771898860740196'),
            ('1', '54.5454545454545', '-0.907955563994678'),
            ('1', '56.5656565656566', '0.0169879838359329'),
            ('1', '58.5858585858586', '0.893195373633514'),
            ('1', '60.6060606060606', '-0.793050280095917'),
            ('1', '62.6262626262626', '-0.204145206571869'),
            ('1', '64.6464646464647', '0.97042401316637'),
            ('1', '66.6666666666667', '-0.639018014191441'),
            ('1', '68.6868686868687', '-0.415206417907771'),
            ('1', '70.7070707070707', '0.999774524598088'),
            ('1', '72.7272727272727', '-0.45345830674874'),
            ('1', '74.7474747474748', '-0.605782453902495'),
            ('1', '76.7676767676768', '0.979798832111196'),
            ('1', '78.7878787878788', '-0.24552617796807'),
            ('1', '80.8080808080808', '-0.766470802107292'),
            ('1', '82.8282828282828', '0.91148248300339'),
            ('1', '84.8484848484849', '-0.0254804435454899'),
            ('1', '86.8686868686869', '-0.889343528737211'),
            ('1', '88.8888888888889', '0.798196023416984'),
            ('1', '90.9090909090909', '0.195822427884824'),
            ('1', '92.9292929292929', '-0.96833842174277'),
            ('1', '94.949494949495', '0.645528704597425'),
            ('1', '96.969696969697', '0.407463943907474'),
            ('1', '98.989898989899', '-0.999558084155901'),
            ('1', '101.010101010101', '0.461012724169788'),
            ('1', '103.030303030303', '0.599002277570326'),
            ('1', '105.050505050505', '-0.981462221243566'),
            ('1', '107.070707070707', '0.253751607631983'),
            ('1', '109.090909090909', '0.760987439228696'),
            ('1', '111.111111111111', '-0.914943634534649'),
            ('1', '113.131313131313', '0.0339710647287554'),
            ('1', '115.151515151515', '0.885427513786567'),
            ('1', '117.171717171717', '-0.803284173377832'),
            ('1', '119.191919191919', '-0.187485519746762'),
            ('1', '121.212121212121', '0.966182960435219'),
            ('1', '123.232323232323', '-0.651992817262944'),
            ('1', '125.252525252525', '-0.399692069588187'),
            ('1', '127.272727272727', '0.999269521193551'),
            ('1', '129.292929292929', '-0.468533877491422'),
            ('1', '131.313131313131', '-0.592178880584412'),
            ('1', '133.333333333333', '0.983054793552022'),
            ('1', '135.353535353535', '-0.261958727999303'),
            ('1', '137.373737373737', '-0.755449167753183'),
            ('1', '139.393939393939', '0.91833876885113'),
            ('1', '141.414141414141', '-0.042459234750027'),
            ('1', '143.434343434343', '-0.881447611339305'),
            ('1', '145.454545454545', '0.808314362846038'),
            ('1', '147.474747474747', '0.179135083702311'),
            ('1', '149.49494949495', '-0.96395778476974'),
            ('1', '151.515151515152', '0.658409885773808'),
            ('1', '153.535353535354', '0.391891355724901'),
            ('1', '155.555555555556', '-0.998908856532128'),
            ('1', '157.575757575758', '0.476021224029276'),
            ('1', '159.59595959596', '0.585312755282924'),
            ('1', '161.616161616162', '-0.984576434125452'),
            ('1', '163.636363636364', '0.270146946890119'),
            ('1', '165.656565656566', '0.749856387291453'),
            ('1', '167.676767676768', '-0.921667640978929'),
            ('1', '169.69696969697', '0.0509443411504216'),
            ('1', '171.717171717172', '0.877404108562929'),
            ('1', '173.737373737374', '-0.813286228871258'),
            ('1', '175.757575757576', '-0.170771722272242'),
            ('1', '177.777777777778', '0.961663055302567'),
            ('1', '179.79797979798', '-0.664779447110237'),
            ('1', '181.818181818182', '-0.384062365173494'),
            ('1', '183.838383838384', '0.998476116195177'),
            ('1', '185.858585858586', '-0.483474223538306'),
            ('1', '187.878787878788', '-0.578404397087056'),
            ('1', '189.89898989899', '0.986027033170783'),
            ('1', '191.919191919192', '-0.278315673488358'),
            ('1', '193.939393939394', '-0.744209501387278'),
            ('1', '195.959595959596', '0.924930010725241'),
            ('1', '197.979797979798', '-0.0594257716920959'),
            ('1', '200', '-0.873297297213995')
    ) AS raw_samples (input_label, input, output)
),

-- Global aggregate: collapses all training rows into a single model value.
trained_model AS (
    SELECT
        LEARN_LIBSVM_REGRESSOR(
            ts.label,
            MAP(ARRAY[ts.feature_label], ARRAY[ts.feature]),
            -- NOTE(review): in libsvm, 'degree' is only used by the poly
            -- kernel, so degree=0 has no effect here. The sigmoid kernel
            -- (tanh(gamma * u'v + coef0)) on unscaled inputs in [0, 200]
            -- may saturate and fit sin(x) poorly. Consider kernel=rbf with
            -- an explicit gamma; confirm Presto's libsvm defaults first.
            'degree=0,kernel=sigmoid'
        ) AS model
    FROM training_data AS ts
),

-- Held-out inputs x = 1..100 with their true sin(x). The feature key must
-- match the key used at training time (1).
validation_data AS (
    SELECT
        CAST(x AS DOUBLE) AS input,
        SIN(CAST(x AS DOUBLE)) AS label,
        MAP(ARRAY[CAST(1 AS BIGINT)], ARRAY[CAST(x AS DOUBLE)]) AS features
    FROM UNNEST(SEQUENCE(1, 100)) AS grid (x)
)

-- One row per validation input: the true sin(x) next to the model's
-- prediction, ordered by input so the output is deterministic.
SELECT
    vd.label,
    REGRESS(vd.features, tm.model) AS prediction
FROM trained_model AS tm
CROSS JOIN validation_data AS vd
ORDER BY vd.input
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment