Last active
May 17, 2023 08:56
-
-
Save pdet/364f9b72e66cbae4b43303d58448f74f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import duckdb | |
import pyarrow as pa | |
import matplotlib.pyplot as plt | |
class LinearRegression(nn.Module):
    """Linear regression model made of a single fully connected layer.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of values predicted per sample.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The single affine layer holds every trainable parameter of the model.
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return its output."""
        return self.linear(x)
def train_linear_regression(X, y, learning_rate=0.01, epochs=1000):
    """Fit a ``LinearRegression`` model to ``(X, y)`` with full-batch SGD.

    Args:
        X: Feature tensor. When it has two or more dimensions, the size of
            its last dimension is used as the number of input features;
            otherwise a single feature is assumed.
        y: Target tensor. When it has two or more dimensions, the size of
            its last dimension is used as the number of outputs; otherwise
            a single output is assumed.
        learning_rate: Step size passed to ``torch.optim.SGD``.
        epochs: Number of optimisation steps; every step uses all of ``X``.

    Returns:
        The trained ``LinearRegression`` model.
    """
    # Derive the layer sizes from the data instead of hard-coding 1 -> 1, so
    # the same routine also handles multi-feature / multi-target tensors.
    input_dim = X.shape[-1] if X.dim() > 1 else 1
    output_dim = y.shape[-1] if y.dim() > 1 else 1

    model = LinearRegression(input_dim, output_dim)
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        # Forward pass and loss on the whole data set.
        y_pred = model(X)
        loss = criterion(y_pred, y)

        # Clear gradients left over from the previous step before
        # back-propagating, then update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress every 100 epochs.
        if (epoch + 1) % 100 == 0:
            print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, epochs, loss.item()))

    return model
def predict_linear_regression(model, X):
    """Run ``model`` on ``X`` without gradient tracking.

    Args:
        model: Callable (for example an ``nn.Module``) mapping a tensor to a
            tensor.
        X: Input tensor passed to ``model`` unchanged.

    Returns:
        The model output converted to a NumPy array.
    """
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        y_pred = model(X)
    # ``.cpu()`` is a no-op for CPU tensors and lets device-resident outputs
    # be converted as well (``.numpy()`` only works on CPU tensors).
    return y_pred.cpu().numpy()
# Pull up to 10,000 (trip_distance, fare_amount) rows from the Parquet file;
# the result is indexed by column name below.
data = duckdb.execute("SELECT trip_distance as trip_distance, fare_amount as fare_amount from 'yellow_tripdata_2016-01.parquet' limit 10000").fetchnumpy()

# Reshape each column to (n_rows, 1): one feature and one target per row.
X = torch.tensor(data['trip_distance'], dtype=torch.float32).reshape(-1, 1)
y = torch.tensor(data['fare_amount'], dtype=torch.float32).reshape(-1, 1)

# Train the linear regression model
model = train_linear_regression(X, y)

# Three sample trip distances shaped like the training features.
# NOTE(review): X_test is not referenced by any code visible in this file.
X_test = torch.tensor([2.5, 4.2, 0.8], dtype=torch.float32).reshape(-1, 1)
def predict_fare(x):
    """Predict fares for a column of trip distances using the global ``model``.

    Args:
        x: Arrow column exposing ``.chunks``; each chunk is converted with
            ``to_numpy()``.
            NOTE(review): presumably a ``pyarrow.ChunkedArray`` of DOUBLE
            values as registered with DuckDB -- confirm against the caller.

    Returns:
        A ``pyarrow.Table`` with a single float32 column named
        ``predicted_value`` holding one prediction per input row.
    """
    chunk_tensors = [torch.from_numpy(chunk.to_numpy()).float() for chunk in x.chunks]
    if chunk_tensors:
        # Chunks are consecutive row ranges of ONE column, so they must be
        # concatenated along the row axis. Stacking them along dim=1 would
        # turn every chunk into an extra feature column and fail outright
        # when chunk lengths differ.
        features = torch.cat(chunk_tensors).reshape(-1, 1)
    else:
        # No chunks at all: feed the model an empty (0, 1) batch.
        features = torch.empty((0, 1), dtype=torch.float32)

    predicted = predict_linear_regression(model, features).flatten()

    schema = pa.schema([('predicted_value', pa.float32())])
    return pa.Table.from_arrays([pa.array(predicted, type=pa.float32())], schema=schema)
# Register predict_fare with DuckDB as an Arrow-typed function that takes one
# DOUBLE argument and is declared to return DOUBLE.
con = duckdb.connect()
con.create_function('predict_fare', predict_fare, ['DOUBLE'], 'DOUBLE', type='arrow')

# Score 100 rows of a different month's file inside SQL and fetch the result
# as a DataFrame.
duck_df_sample = con.sql("SELECT predict_fare(trip_distance) as predicted_fare, fare_amount, trip_distance from 'yellow_tripdata_2016-02.parquet' LIMIT 100").df()

# Overlay predicted fares (blue) and actual fares (red) on the same axes.
ax = duck_df_sample.plot(kind='scatter', x='trip_distance', y='predicted_fare', c='blue', alpha=0.5)
duck_df_sample.plot(kind='scatter', x='trip_distance', y='fare_amount', c='red', alpha=0.5, ax=ax)

# set the x-axis label
ax.set_xlabel('Trip Distance')
# set the y-axis label
ax.set_ylabel('Fare')
# set the plot title
ax.set_title('Predicted Fare vs Actual Fare by Trip Distance')

# show the plot
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment