预测GDP应用:Numpy 线性回归+Matplotlib 作图

   日期:2020-09-13     浏览:98    评论:0    
核心提示:预测GDP应用:Numpy 线性回归+Matplotlib 作图需求通过2000~2019年中美两国的GDP数据,预测后续几年GDP的发展趋势:读取.csv文件,并将字符串调整为浮点型进行二阶线性回归模拟支持数据可视化保命声明:用线性回归预测GDP发展并不合理,只是作为python学习参考。如果想要了解更有意义的GDP对比可以参考b站翟老师的:https://b23.tv/6aYFVf成品效果原数据格式.csv 文件(“testgdp.csv”),gdp数据每千位均被 “,”

预测GDP应用:Numpy 线性回归+Matplotlib 作图

需求

通过2000~2019年中美两国的GDP数据,预测后续几年GDP的发展趋势:

  • 读取.csv文件,并将字符串调整为浮点型
  • 进行二阶线性回归模拟
  • 支持数据可视化

保命声明:用线性回归预测GDP发展并不合理,只是作为python学习参考。

如果想要了解更有意义的GDP对比可以参考b站翟老师的:https://b23.tv/6aYFVf

成品效果

原数据格式

.csv 文件(“testgdp.csv”),gdp数据每千位均被 “,” 隔开

需求拆解

1、csv文件读取

以测试文件"testgdp.csv"为例,目标将csv数据读取成适合进行线性回归的格式ndarray

方法一pandasread_csv()函数

import pandas as pd
import numpy as np
data = pd.read_csv("testgdp.csv")
df = pd.DataFrame(data)
print(df.head())

years = np.array(df.years) #可以转化为 ndarray
years

方法二: python自带的 open()函数

import csv
import numpy as np
data_list = []
with open("testgdp.csv",encoding = 'utf-8') as csvfile:
    csv_reader = csv.reader(csvfile)
    for row in csv_reader:
        data_list.append(row[0:3])#第3~7列为空数据,需要排除
    data1 = np.array(data_list)
    data2 = np.delete(data1,-1,axis=0)#删除最后一行空值行,axis=1时可删除列
data2

2、对“xxx,xxx,xxx”格式字符串转化为数字

split():用指定分隔符对 字符串 进行切片,变为 list

strr.split (str="", num=string.count(str))

  • strr 为原字符串
  • str 为分隔符号
  • num – 分割次数。默认为 -1, 即分隔所有
def intt(list,exc_rate=1):#将"xxx,xxx,xxx,xxx,xxx"格式的str转化为 整型,exc_rate为汇率
    list_new = []
    for strr in list:
        int_list = strr.split(',') # 分割str,转化为列表
        lenth = len(int_list)
        result = 0
        for n in range(lenth):
            ii = int(int_list[n])
            result = result + ii*1000**(lenth-n-1)*exc_rate
        list_new.append(result)
    return list_new

list = ['11,061,552,790,044','14,342,902,842,915','234,322,342,111','123,212,231']
intt(list)

3、线性回归:np.polyfit()多项式拟合、np.polyval()多项式曲线求值

P = np.polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False)

  • x, y:一般是array格式的数组,分别代表自变量和因变量
  • deg:阶数(需要整型),即需要进行几阶线性回归
  • 其他数据不太常用,可以不输入,即使用默认参数。如果需要了解可以参考:numpy.polyfit

输出参数 P为拟合多项式
P ( 1 ) x n + P ( 2 ) x n − 1 + . . . + P ( n ) x + P ( n + 1 ) 的 系 数 组 合 P(1)x^n + P(2)x^{n-1} +...+ P(n)x + P(n+1) 的 系数组合 P(1)xn+P(2)xn1+...+P(n)x+P(n+1)
P 为[ 1, 2, 3]时,代表多项式线性回归的结果为
Y = x 2 + 2 x + 3 Y = x^2+2x+3 Y=x2+2x+3
可以用np.polyval()方法输出预测结果Y,即

Y = np.polyval(P, x)

4、模块输出可视化图表

要用到matplotlib.pyplot,这个模块内容非常非常多,现在根据需求选取几个易用的函数

官方文档:https://matplotlib.org/api/pyplot_api.html

功能一:绘制关系曲线

绘制一条x,y关系曲线,红色,宽度为2,标签为label

plt.plot(x, y, color="red”,linestyle="-", linewidth=2.0, label=‘label')
  • x, y:与前面的x, y相同,支持array格式的数组,分别代表自变量和因变量
  • 设置label标签有助于后续生成图例
import matplotlib.pyplot as plt 
x=[1,2,3,5] 
y=[2,3,5,9] 
plt.plot(x, y,color="red",linestyle="-", linewidth=2.0,label='label1') 
plt.show()

功能二:新增图例

plt.legend(loc=*'best'*,label=lable_list)

loc=‘best’时图例自动‘安家’在一个坐标面内的数据图表最少的位置,可以设置为指定位置。

参考链接:https://zhuanlan.zhihu.com/p/111108841

功能三:箭头标注关键信息

对第三个坐标点用红色箭头标注,箭头离坐标相差0.05个单位。同时在(4,2)提醒’this is the annotate’.

plt.annotate('this is the annotate', xy=(x[2],y[2]), xycoords='data', xytext=(4,2),
arrowprops=dict(facecolor='red', shrink=0.05))

可以参考https://blog.csdn.net/wizardforcel/article/details/54782628

实例代码

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
def intt(list,exc_rate=1):#将"xxx,xxx,xxx,xxx,xxx"格式的str转化为 整型,exc_rate为汇率
    list_new = []
    for strr in list:
        int_list = strr.split(',') # 分割str,转化为列表
        lenth = len(int_list)
        result = 0
        for n in range(lenth):
            ii = int(int_list[n])
            result = result + ii*1000**(lenth-n-1)*exc_rate
        list_new.append(result)
    return list_new

  def pre(n):#n为预测时间(年)
      data = pd.read_csv("testgdp.csv")
      df = pd.DataFrame(data)
      df = df.drop([19])#删除空行
      years = np.array(df.years)
      cn = intt(np.array(df.cn))
      usa = intt(np.array(df.us))
      model_cn = np.polyfit(years,cn,2)#阶线性回归cn
      model_usa = np.polyfit(years,usa,2)#2阶维线性回归usa
      overyear_list = []
      overusa_list = []
      overcn_list = []
    for i in range(n):#预测n年后gdp数据表现
        yy=2020+i
        cn_gdp=np.polyval(model_cn,yy)
        usa_gdp=np.polyval(model_usa,yy)
        if cn_gdp>usa_gdp:#判断何时中国gdp超过美国,并记录下来
            overyear_list.append(yy)
            overusa_list.append(usa_gdp)
            overcn_list.append(cn_gdp)
        cn = np.append(cn,cn_gdp)
        usa = np.append(usa,usa_gdp)
        years=np.append(years,yy)
    plt.plot(years, cn,color="red",linestyle="-", linewidth=2.0,label='CN')
    plt.plot(overyear_list, overcn_list, color="red", linestyle="-", linewidth=4.0)#加粗超过美国的部分
    plt.plot(years, usa,color="blue",
             linestyle="-", linewidth=2.0,label='USA')
    plt.plot(years[0:len(years)-len(overyear_list)+1],
             usa[0:len(years)-len(overyear_list)+1],
             color="blue", linestyle="-", linewidth=4.0)
    plt.legend(loc='upper left')#图例,位置左上
    plt.annotate(s=("%d:CN%.1ftrillion ,USA%.1ftrillion"%(overyear_list[0],overcn_list[0]/(10**12),overusa_list[0]/(10**12))),xy=(overyear_list[0],overcn_list[0]),
                 xytext=(overyear_list[0]+n/10,overcn_list[0]*0.6)
                 ,arrowprops=dict(facecolor='red', shrink=0.05))#arrowprops箭头
    plt.show()
pre(40)

后续进阶

  • 增加爬虫功能(合法的那种!)
  • 优化可视化图表(增加图表样式,增加图像交互能力,如调用Pyecharts
  • 增加更多纬度数据,采用逻辑回归
  • 增加与数据库对接的功能
 
打赏
 本文转载自:网络 
所有权利归属于原作者,如文章来源标示错误或侵犯了您的权利请联系微信13520258486
更多>最近资讯中心
更多>最新资讯中心
0相关评论

推荐图文
推荐资讯中心
点击排行
最新信息
新手指南
采购商服务
供应商服务
交易安全
关注我们
手机网站:
新浪微博:
微信关注:

13520258486

周一至周五 9:00-18:00
(其他时间联系在线客服)

24小时在线客服