交互式线性回归分析工具

线性回归是一种分析方法,用于估计含有一个或多个自变量的线性方程的系数,使该方程能够最好地预测因变量的值。线性回归通过拟合一条直线,使实际值与预测值之间的差异最小。它在商业领域应用广泛,常用于评估趋势并进行估计或预测。本文将介绍如何使用Python的Tkinter库创建一个交互式线性回归分析工具。

线性回归模型可以表示为:

Y = a + bX

其中,X是自变量,Y是因变量。方程中的b项代表直线的斜率,a项代表截距,即当X为零时Y的值。

使用代码

程序的主要部分是使用Tkinter设计应用程序的用户界面,并声明所需的变量。以下是Python代码的示例:

# NOTE: the import order is kept exactly as in the original because the two
# wildcard imports are order-sensitive: tkinter.tix is imported after tkinter
# so that its names (Balloon is used below) are in scope.
# NOTE(review): tkinter.tix is deprecated since Python 3.6 and removed in 3.13.
from tkinter import *
from tkinter import messagebox
from tkinter.tix import *
import pandas as pd
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import os

# Application state shared by the callback functions.
distances = []  # distance values, kept as the text shown in lstdistance
fares = []      # fare values, index-aligned with `distances`
data = {}       # filled by plot() with the "distances" / "fares" columns

# The callbacks referenced below (add, delete, plot, clearplot, savedata,
# opendata, listselected) must already be defined when these lines run,
# because each one is looked up by name while the widget is constructed.
window = Tk()
window.title("Linear Regression")
window.geometry("800x500")
tip = Balloon(window)

# --- Input fields -----------------------------------------------------------
lbldistance = Label(window, text="Enter Distance:", anchor="w")
lbldistance.place(x=50, y=50, width=100)
txtdistance = Entry(window)
txtdistance.place(x=150, y=50, width=100)

lblfare = Label(window, text="Enter Fare:", anchor="w")
lblfare.place(x=50, y=75, width=100)
txtfare = Entry(window)
txtfare.place(x=150, y=75, width=100)

# --- Action buttons ---------------------------------------------------------
btnadd = Button(window, text="Add/Update", command=add)
btnadd.place(x=50, y=100, width=100)
btndelete = Button(window, text="Delete", command=delete)
btndelete.place(x=150, y=100, width=100)
btnplot = Button(window, text="Plot", command=plot)
btnplot.place(x=50, y=125, width=100)
btnclear = Button(window, text="Clear", command=clearplot)
btnclear.place(x=150, y=125, width=100)
btnsave = Button(window, text="Save Data", command=savedata)
btnsave.place(x=50, y=150, width=100)
btnopen = Button(window, text="Open Data", command=opendata)
btnopen.place(x=150, y=150, width=100)

# --- Three side-by-side listboxes: distance, actual fare, predicted fare ----
lstdistance = Listbox(window)
lstdistance.place(x=50, y=175, width=67)
lstfare = Listbox(window)
lstfare.place(x=120, y=175, width=67)
lstpredfare = Listbox(window)
lstpredfare.place(x=190, y=175, width=67)

# --- Fitted model coefficients ----------------------------------------------
lblintercept = Label(window, text="Y-Intercept:", anchor="w")
lblintercept.place(x=50, y=350, width=100)
txtintercept = Entry(window)
txtintercept.place(x=150, y=350, width=100)

lblslope = Label(window, text="Slope:", anchor="w")
lblslope.place(x=50, y=375, width=100)
txtslope = Entry(window)
txtslope.place(x=150, y=375, width=100)

# The original sequence was "<>", which is not a valid Tk event pattern; the
# listbox selection virtual event is the one listselected() is written for.
lstdistance.bind("<<ListboxSelect>>", listselected)

# Hover tooltips for the three listboxes.
tip.bind_widget(lstdistance, balloonmsg="Distances")
tip.bind_widget(lstfare, balloonmsg="Actual Fares")
tip.bind_widget(lstpredfare, balloonmsg="Predicted Fares")

window.mainloop()

添加和更新数据

用户定义的add()函数用于添加或更新存储在列表中的distance和fare。如果distance尚未在列表中,则添加新的distance和fare;如果distance已经添加,则更新fare。然后使用updatelists()函数更新前端GUI中的数据,最后调用plot()函数绘制数据。

def add():
    """Add a new distance/fare pair, or update the fare of a known distance.

    Reads the text of the distance and fare entry fields. If the distance is
    already stored, only its fare is replaced; otherwise both values are
    appended. The listboxes are then refreshed and the chart is redrawn.

    Both values must be whole numbers because plot() converts every stored
    value with int(). Invalid input is rejected with an error dialog before
    anything is stored, so a typo cannot leave an unplottable row behind.
    """
    distance = txtdistance.get().strip()
    fare = txtfare.get().strip()

    # Validate first: nothing is mutated unless both values are usable.
    try:
        int(distance)
        int(fare)
    except ValueError:
        messagebox.showerror("Error", "Distance and fare must be whole numbers")
        return

    if distance in distances:
        fares[distances.index(distance)] = fare
    else:
        distances.append(distance)
        fares.append(fare)

    updatelists()
    plot()

更新列表

updatelists()函数的代码如下:

def updatelists():
    """Rebuild the distance and fare listboxes from the in-memory lists."""
    for listbox, values in ((lstdistance, distances), (lstfare, fares)):
        listbox.delete(0, END)
        for value in values:
            listbox.insert(END, value)

绘制图表

plot()函数用于绘制图表。数据存储在一个字典中,其中包含距离列表和票价列表。模型是sklearn.linear_model包中的LinearRegression类的实例。fit()函数用于训练模型,predict()函数用于生成预测的票价。然后使用matplotlib库以距离为横轴,绘制实际票价和预测票价。

def plot():
    """Fit a linear regression to the listed data and draw it in the window.

    Reads the distance and fare listboxes, converts every entry with int(),
    stores both columns in the module-level ``data`` dict, fits a
    LinearRegression of fare on distance, and then:

    * fills the predicted-fare listbox with the model's predictions,
    * shows the intercept and slope (rounded to 2 decimals),
    * replaces any existing chart with a new actual-vs-predicted plot.

    When the listboxes are empty there is nothing to fit, so the chart,
    predictions and coefficients are cleared instead of being left stale.

    Raises:
        ValueError: if a listbox entry cannot be converted with int().
    """
    # A bare Figure is used instead of plt.figure(): pyplot keeps every figure
    # it creates registered until it is explicitly closed, so calling
    # plt.figure() on each redraw would accumulate figures.
    from matplotlib.figure import Figure

    distance_values = [int(n) for n in lstdistance.get(0, END)]
    fare_values = [int(n) for n in lstfare.get(0, END)]

    if not distance_values:
        # No data left (e.g. the last row was deleted): reset derived output.
        data["distances"] = []
        data["fares"] = []
        lstpredfare.delete(0, END)
        txtintercept.delete(0, END)
        txtslope.delete(0, END)
        clearplot()
        return

    data["distances"] = distance_values
    data["fares"] = fare_values
    df = pd.DataFrame(data)
    X = df[["distances"]]  # 2-D feature frame, as fit()/predict() expect
    y = df["fares"]

    model = LinearRegression()
    model.fit(X, y)
    y_pred = model.predict(X)

    lstpredfare.delete(0, END)
    for predicted in y_pred:
        lstpredfare.insert(END, predicted)

    txtintercept.delete(0, END)
    txtintercept.insert(0, str(round(model.intercept_, 2)))
    txtslope.delete(0, END)
    txtslope.insert(0, str(round(model.coef_[0], 2)))

    # Remove the previous chart before embedding the new one.
    clearplot()

    fig = Figure()
    ax = fig.add_subplot(111)
    ax.plot(X, y, color="red", marker="o", markerfacecolor="blue", label="Actual Fare")
    ax.plot(X, y_pred, color="blue", marker="o", markerfacecolor="blue", label="Predicted Fare")
    ax.set_title("Linear Regression Example")
    ax.set_xlabel("Distance")
    ax.set_ylabel("Fare")
    ax.legend()

    canvas = FigureCanvasTkAgg(fig, master=window)
    canvas.draw()
    canvas.get_tk_widget().pack()

清除图表

clearplot()函数的代码如下:

def clearplot():
    """Destroy every Canvas widget that is a direct child of the main window.

    plot() embeds its chart as a Canvas packed into ``window``; removing those
    canvases clears the chart from the screen.
    """
    # isinstance() replaces the former substring test on str(type(widget)),
    # which would also have matched any unrelated widget whose class or
    # module name merely contains "Canvas".
    for widget in window.winfo_children():
        if isinstance(widget, Canvas):
            widget.destroy()

删除数据

delete()函数用于从列表中删除所选的distance及其对应的fare,并更新图表。

def delete():
    """Delete the selected distance and its fare, then redraw the chart.

    The row selected in the distance listbox identifies the entry. It is
    removed from the ``distances`` and ``fares`` lists and from all three
    listboxes, after which plot() is called. Nothing happens when no row is
    selected.
    """
    selection = lstdistance.curselection()
    if not selection:
        return

    # Use the selection index directly: the listbox rows and the two lists
    # are index-aligned, so this always removes exactly the selected row
    # (a lookup by value would pick the first match among duplicates).
    row = selection[0]
    if row >= len(distances):
        return

    del distances[row]
    del fares[row]
    lstdistance.delete(row)
    lstfare.delete(row)
    lstpredfare.delete(row)
    plot()

选择列表

listselected()函数用于在屏幕上显示从List中选择的distance和fare。

def listselected(event):
    """Copy the selected row's distance and fare into the two entry fields.

    Bound as a listbox event handler; ``event`` is not used. Does nothing
    when the distance listbox has no selection.
    """
    selected = lstdistance.curselection()
    if not selected:
        return

    row = selected[0]
    for entry, values in ((txtdistance, distances), (txtfare, fares)):
        entry.delete(0, END)
        entry.insert(END, values[row])

保存和打开数据

当前的distances和fares列表可以使用savedata()函数保存到CSV文件中:

def savedata():
    """Write the current distance/fare pairs to ``data.csv``.

    The file is created in the current working directory with the columns
    ``distances`` and ``fares`` and no index column.
    """
    # Build the frame from the live lists rather than from the ``data`` dict:
    # ``data`` is only refreshed by plot(), so it can be empty or out of date
    # (for example right after loading a file).
    frame = pd.DataFrame({"distances": distances, "fares": fares})
    frame.to_csv("data.csv", index=False)

保存的距离和票价可以从保存的CSV文件中使用opendata()函数加载:

def opendata():
    """Load distance/fare pairs from ``data.csv`` into the lists and listboxes.

    The first CSV column is taken as distance and the second as fare. The
    existing entries are replaced. An error dialog is shown when the file is
    missing, empty, or has fewer than two columns.
    """
    if not os.path.exists("data.csv"):
        messagebox.showerror("Error", "No data found to load")
        return

    try:
        # Named ``frame`` so the module-level ``data`` dict is not shadowed.
        frame = pd.read_csv("data.csv")
    except pd.errors.EmptyDataError:
        messagebox.showerror("Error", "No data found to load")
        return

    if frame.shape[1] < 2:
        messagebox.showerror("Error", "No data found to load")
        return

    lstdistance.delete(0, END)
    lstfare.delete(0, END)
    # Predictions belong to the previous data set; drop them so the rows of
    # the three listboxes do not show unrelated values side by side.
    lstpredfare.delete(0, END)
    distances.clear()
    fares.clear()

    # Iterate column-wise instead of over ``frame.values``: the combined array
    # upcasts mixed int/float columns to float, which would turn an integer
    # distance such as 10 into the text "10.0".
    for distance, fare in zip(frame.iloc[:, 0], frame.iloc[:, 1]):
        distance_text = str(distance)
        fare_text = str(fare)
        lstdistance.insert(END, distance_text)
        distances.append(distance_text)
        lstfare.insert(END, fare_text)
        fares.append(fare_text)
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485