线性回归是一种分析方法,用于估计线性方程的系数,使该方程能够根据一个或多个自变量最好地预测因变量的值。线性回归通过拟合一条直线,使实际值与预测值之间的差异(通常是误差的平方和)最小。线性回归在商业领域应用广泛,常用于评估趋势并进行估计或预测。本文将介绍如何使用Python的Tkinter库创建一个交互式线性回归分析工具。
线性回归模型可以表示为:
Y = a + bX
其中,X是自变量,Y是因变量。方程中的b项代表直线的斜率,a项代表截距,即当X为零时Y的值。
程序的主要部分是使用Tkinter设计应用程序的用户界面,并声明所需的变量。以下是Python代码的示例:
from tkinter import *
from tkinter import messagebox
from tkinter.tix import *
import pandas as pd
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import os
# Module-level state shared by the callback functions.
distances = []  # distance entries (strings), in list-box order
fares = []  # fare entries (strings), parallel to `distances`
data = {}  # dataset last fitted by plot(): {"distances": [...], "fares": [...]}

# NOTE(review): in a runnable script, add/delete/plot/clearplot/savedata/
# opendata/listselected must be defined before the widgets below reference
# them, and before window.mainloop() blocks.
window = Tk()
window.title("Linear Regression")
window.geometry("800x500")

# NOTE(review): Balloon comes from tkinter.tix, which is deprecated and was
# removed in Python 3.13; tooltips need another implementation there.
tip = Balloon(window)

# Input fields for one (distance, fare) observation.
lbldistance = Label(window, text="Enter Distance:", anchor="w")
lbldistance.place(x=50, y=50, width=100)
txtdistance = Entry(window)
txtdistance.place(x=150, y=50, width=100)
lblfare = Label(window, text="Enter Fare:", anchor="w")
lblfare.place(x=50, y=75, width=100)
txtfare = Entry(window)
txtfare.place(x=150, y=75, width=100)

# Action buttons.
btnadd = Button(window, text="Add/Update", command=add)
btnadd.place(x=50, y=100, width=100)
btndelete = Button(window, text="Delete", command=delete)
btndelete.place(x=150, y=100, width=100)
btnplot = Button(window, text="Plot", command=plot)
btnplot.place(x=50, y=125, width=100)
btnclear = Button(window, text="Clear", command=clearplot)
btnclear.place(x=150, y=125, width=100)
btnsave = Button(window, text="Save Data", command=savedata)
btnsave.place(x=50, y=150, width=100)
btnopen = Button(window, text="Open Data", command=opendata)
btnopen.place(x=150, y=150, width=100)

# Three side-by-side list boxes: distances, actual fares, predicted fares.
lstdistance = Listbox(window)
lstdistance.place(x=50, y=175, width=67)
lstfare = Listbox(window)
lstfare.place(x=120, y=175, width=67)
lstpredfare = Listbox(window)
lstpredfare.place(x=190, y=175, width=67)

# Read-out fields for the fitted model's coefficients.
lblintercept = Label(window, text="Y-Intercept:", anchor="w")
lblintercept.place(x=50, y=350, width=100)
txtintercept = Entry(window)
txtintercept.place(x=150, y=350, width=100)
lblslope = Label(window, text="Slope:", anchor="w")
lblslope.place(x=50, y=375, width=100)
txtslope = Entry(window)
txtslope.place(x=150, y=375, width=100)

# Selecting a distance copies it (and its fare) into the entry fields.
# "<<ListboxSelect>>" is the virtual event Tk's Listbox emits on selection;
# the empty sequence "<>" is not a valid event pattern.
lstdistance.bind("<<ListboxSelect>>", listselected)
tip.bind_widget(lstdistance, balloonmsg="Distances")
tip.bind_widget(lstfare, balloonmsg="Actual Fares")
tip.bind_widget(lstpredfare, balloonmsg="Predicted Fares")

window.mainloop()
用户定义的add()函数用于添加或更新存储在列表中的distance和fare。如果distance尚未在列表中,则添加新的distance和fare;如果distance已经添加,则更新fare。然后使用updatelists()函数更新前端GUI中的数据,最后调用plot()函数绘制数据。
def add():
    """Add a (distance, fare) pair, or update the fare of an existing distance.

    Values are stored as the (whitespace-stripped) strings typed by the user.
    Both fields must parse as finite numbers; otherwise an error dialog is
    shown and nothing is stored, because a non-numeric entry left in the lists
    would make every later regression/plot fail. After a successful change
    the list boxes and the chart are refreshed.
    """
    import math

    distance = txtdistance.get().strip()
    fare = txtfare.get().strip()
    try:
        valid = math.isfinite(float(distance)) and math.isfinite(float(fare))
    except ValueError:
        valid = False
    if not valid:
        messagebox.showerror("Error", "Distance and fare must be numbers")
        return

    if distance in distances:
        # Known distance: only its fare changes (the distance is unchanged).
        fares[distances.index(distance)] = fare
    else:
        distances.append(distance)
        fares.append(fare)
    updatelists()
    plot()
updatelists()函数的代码如下:
def updatelists():
    """Rebuild the distance and fare list boxes from the module-level lists."""
    # Empty both boxes first, then refill each one in list order.
    for box in (lstdistance, lstfare):
        box.delete(0, END)
    for box, items in ((lstdistance, distances), (lstfare, fares)):
        for item in items:
            box.insert(END, item)
plot()函数用于拟合模型并绘制图表。数据以字典形式存储,其中分别保存距离列表和票价列表。模型是sklearn.linear_model包中LinearRegression类的实例:fit()函数用于训练模型,predict()函数用于生成预测票价。最后使用matplotlib库绘制实际票价和预测票价随距离变化的曲线。
def plot():
    """Fit a linear regression of fare on distance and redraw the chart.

    Reads the values currently shown in the distance and fare list boxes,
    fits ``fare = intercept + slope * distance`` with LinearRegression, fills
    the predicted-fare list box and the intercept/slope fields, stores the
    fitted dataset in the global ``data`` dict, and embeds a fresh matplotlib
    figure in the main window.

    When there is no data, all previous results (predictions, coefficients,
    chart, ``data``) are cleared so nothing stale remains on screen.

    Raises:
        ValueError: if a list-box value is not numeric.
    """
    from matplotlib.figure import Figure

    distances = list(lstdistance.get(0, END))
    if not distances:
        # E.g. the last point was just deleted: clear stale results.
        data["distances"] = []
        data["fares"] = []
        lstpredfare.delete(0, END)
        txtintercept.delete(0, END)
        txtslope.delete(0, END)
        clearplot()
        return
    fares = list(lstfare.get(0, END))

    # pd.to_numeric keeps integer input as int and also accepts decimals,
    # whereas int() rejected values such as "2.5".
    distances = pd.to_numeric(distances).tolist()
    fares = pd.to_numeric(fares).tolist()
    data["distances"] = distances
    data["fares"] = fares

    df = pd.DataFrame(data)
    X = df[["distances"]]
    y = df["fares"]
    model = LinearRegression()
    model.fit(X, y)
    y_pred = model.predict(X)

    lstpredfare.delete(0, END)
    for n in y_pred:
        # Round for display, consistent with the intercept/slope fields.
        lstpredfare.insert(END, str(round(float(n), 2)))
    txtintercept.delete(0, END)
    txtintercept.insert(0, str(round(model.intercept_, 2)))
    txtslope.delete(0, END)
    txtslope.insert(0, str(round(model.coef_[0], 2)))

    clearplot()
    # A bare Figure (not plt.figure()) is not registered with pyplot, so it is
    # freed once its canvas is destroyed instead of accumulating on re-plot.
    fig = Figure()
    ax = fig.add_subplot(111)
    ax.plot(X, y, color="red", marker="o", markerfacecolor="blue", label="Actual Fare")
    ax.plot(X, y_pred, color="blue", marker="o", markerfacecolor="blue", label="Predicted Fare")
    ax.set_title("Linear Regression Example")
    ax.set_xlabel("Distance")
    ax.set_ylabel("Fare")
    ax.legend()
    canvas = FigureCanvasTkAgg(fig, master=window)
    canvas.draw()
    canvas.get_tk_widget().pack()
clearplot()函数的代码如下:
def clearplot():
    """Destroy every child of the main window whose type name contains "Canvas".

    This removes the chart widget embedded by plot(), if one is present.
    """
    doomed = [child for child in window.winfo_children()
              if "Canvas" in str(type(child))]
    for child in doomed:
        child.destroy()
delete()函数用于从列表中删除当前选中的distance及其对应的fare,并更新图表。
def delete():
    """Remove the selected distance and its fare, then refresh the chart.

    The selected row index is used directly, because the list boxes are kept
    in the same order as the ``distances``/``fares`` lists. This avoids
    matching by value, which breaks when ``Listbox.get`` returns a non-string
    or when the same distance appears twice.

    Does nothing when no distance is selected.
    """
    selection = lstdistance.curselection()
    if not selection:
        return
    i = selection[0]
    if i >= len(distances):
        # Defensive: list box and backing lists out of sync; nothing to delete.
        return
    del distances[i]
    del fares[i]
    lstdistance.delete(i)
    lstfare.delete(i)
    lstpredfare.delete(i)
    plot()
listselected()函数用于在输入框中显示从列表框中选中的distance及其对应的fare。
def listselected(event):
    """Copy the selected distance and its fare into the two entry fields.

    Bound to the distance list box's selection event; ``event`` is unused.
    """
    picked = lstdistance.curselection()
    if not picked:
        return
    row = picked[0]
    # Distance field first, then fare field: clear each and write the value.
    for entry, source in ((txtdistance, distances), (txtfare, fares)):
        entry.delete(0, END)
        entry.insert(END, source[row])
当前的distances和fares列表可以使用savedata()函数保存到CSV文件中:
def savedata(filename="data.csv"):
    """Write the current distances and fares to a CSV file.

    The file has two columns, "distances" and "fares" (no index column), in
    the order opendata() reads them back.

    The in-memory ``distances``/``fares`` lists are saved rather than the
    global ``data`` dict. ``data`` is only refreshed by plot(), so it can be
    stale, e.g. right after opening a file.

    Args:
        filename: Destination path; defaults to "data.csv" in the working
            directory.
    """
    pd.DataFrame({"distances": distances, "fares": fares}).to_csv(filename, index=False)
可以使用opendata()函数从已保存的CSV文件中重新加载距离和票价:
def opendata(filename="data.csv"):
    """Load distances and fares from a CSV file (as written by savedata()).

    The first column is read as distances and the second as fares; rows with
    a missing value in either column are skipped. The in-memory lists and the
    list boxes are replaced, and plot() is called so that predictions,
    coefficients and the chart match the loaded data. A missing, empty or
    malformed file shows an error dialog and leaves the current data as is.

    Args:
        filename: CSV path to read; defaults to "data.csv" in the working
            directory.
    """
    if not os.path.exists(filename):
        messagebox.showerror("Error", "No data found to load")
        return
    try:
        # Named `frame` so it does not shadow the global `data` dict.
        frame = pd.read_csv(filename)
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        messagebox.showerror("Error", f"Could not read {filename}: {err}")
        return
    if frame.shape[1] < 2:
        messagebox.showerror("Error", f"{filename} must have distance and fare columns")
        return
    # Incomplete rows cannot be fitted by the regression model.
    frame = frame.iloc[:, :2].dropna()

    lstdistance.delete(0, END)
    lstfare.delete(0, END)
    lstpredfare.delete(0, END)  # predictions belonged to the previous dataset
    distances.clear()
    fares.clear()
    for distance, fare in frame.itertuples(index=False):
        distances.append(str(distance))
        fares.append(str(fare))
        lstdistance.insert(END, distances[-1])
        lstfare.insert(END, fares[-1])
    plot()