Skip to content

Instantly share code, notes, and snippets.

@pdet
Last active May 17, 2023 08:56
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pdet/364f9b72e66cbae4b43303d58448f74f to your computer and use it in GitHub Desktop.
Save pdet/364f9b72e66cbae4b43303d58448f74f to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import duckdb
import pyarrow as pa
import matplotlib.pyplot as plt
class LinearRegression(nn.Module):
    """Linear model consisting of a single ``nn.Linear`` layer."""

    def __init__(self, input_dim, output_dim):
        """Build the underlying linear layer.

        Args:
            input_dim: Size of each input sample (``in_features``).
            output_dim: Size of each output sample (``out_features``).
        """
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return the linear layer applied to ``x``."""
        return self.linear(x)
def train_linear_regression(X, y, learning_rate=0.01, epochs=1000):
    """Fit a ``LinearRegression`` model to ``(X, y)`` by gradient descent.

    Every epoch runs one forward pass over the whole of ``X``, computes the
    mean squared error against ``y`` and takes one SGD step. The current
    loss is printed every 100 epochs.

    Args:
        X: Input tensor; its last dimension is used as the feature count.
        y: Target tensor; its last dimension is used as the output count.
        learning_rate: Step size handed to ``torch.optim.SGD``.
        epochs: Number of optimisation steps to run.

    Returns:
        The trained ``LinearRegression`` instance.
    """
    # Size the layer from the data rather than hard-coding 1 -> 1.
    # One-dimensional inputs keep the previous single-feature sizing.
    input_dim = X.shape[-1] if X.dim() > 1 else 1
    output_dim = y.shape[-1] if y.dim() > 1 else 1

    model = LinearRegression(input_dim, output_dim)
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        # Forward pass and loss.
        y_pred = model(X)
        loss = criterion(y_pred, y)

        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

    return model
def predict_linear_regression(model, X):
    """Run ``model`` on ``X`` without gradient tracking.

    Args:
        model: Callable (for example an ``nn.Module``) that maps a tensor
            to a tensor.
        X: Input tensor passed to ``model`` unchanged.

    Returns:
        The model output converted to a NumPy array.
    """
    with torch.no_grad():
        y_pred = model(X)
    # .numpy() only works on CPU tensors; .cpu() is a no-op when the
    # output already lives there.
    return y_pred.cpu().numpy()
# Pull 10,000 (trip_distance, fare_amount) rows from the January 2016 file
# as a dict of NumPy arrays keyed by column name.
data = duckdb.execute(
    "SELECT trip_distance as trip_distance, fare_amount as fare_amount from 'yellow_tripdata_2016-01.parquet' limit 10000"
).fetchnumpy()

# Turn each column into a float32 column vector: X holds the distances,
# y holds the fares.
X, y = (
    torch.tensor(data[column], dtype=torch.float32).reshape(-1, 1)
    for column in ('trip_distance', 'fare_amount')
)

# Fit the model; predict_fare below reads this module-level `model`.
model = train_linear_regression(X, y)

# Three hand-picked distances shaped as a (3, 1) batch.
X_test = torch.tensor([[2.5], [4.2], [0.8]], dtype=torch.float32)
def predict_fare(x):
    """Predict a fare for every value in the Arrow column ``x``.

    Uses the module-level ``model`` through ``predict_linear_regression``.

    Args:
        x: Column of trip distances exposing ``.chunks``.
            NOTE(review): presumably a ``pyarrow.ChunkedArray`` as handed
            to ``type='arrow'`` UDFs by DuckDB; confirm against the
            DuckDB version in use.

    Returns:
        A one-column ``pyarrow.Table`` named ``predicted_value`` holding
        one prediction per input row.
    """
    # `model` is only read here, so no `global` declaration is required.
    chunk_tensors = [torch.from_numpy(chunk.to_numpy()).float() for chunk in x.chunks]

    # Join the chunks end to end and present each row as one single-feature
    # sample. Stacking along dim=1 would turn every chunk into its own
    # feature column, which is wrong for more than one chunk and fails
    # outright when chunk lengths differ.
    features = torch.cat(chunk_tensors).reshape(-1, 1)

    predicted = predict_linear_regression(model, features).flatten()

    # The schema is used only to supply the output column name.
    schema = pa.schema([('predicted_value', pa.float32())])
    batch = pa.record_batch([predicted], names=schema.names)
    return pa.Table.from_batches([batch])
# Expose predict_fare to SQL as an Arrow-based function taking one DOUBLE
# and returning a DOUBLE.
con = duckdb.connect()
con.create_function(
    'predict_fare',
    predict_fare,
    ['DOUBLE'],
    'DOUBLE',
    type='arrow',
)

# Score 100 rows of the February 2016 file inside the query and fetch the
# result as a pandas DataFrame.
duck_df_sample = con.sql(
    "SELECT predict_fare(trip_distance) as predicted_fare, fare_amount, trip_distance from 'yellow_tripdata_2016-02.parquet' LIMIT 100"
).df()

# Predicted fares in blue; actual fares in red, drawn on the same axes.
ax = duck_df_sample.plot(
    kind='scatter', x='trip_distance', y='predicted_fare', c='blue', alpha=0.5
)
duck_df_sample.plot(
    kind='scatter', x='trip_distance', y='fare_amount', c='red', alpha=0.5, ax=ax
)

# Axis labels and title in a single call.
ax.set(
    xlabel='Trip Distance',
    ylabel='Fare',
    title='Predicted Fare vs Actual Fare by Trip Distance',
)

plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment