2022-12-07 10:43:52 +08:00
|
|
|
# -*-coding:utf-8-*-
|
|
|
|
from flask import Flask, request, make_response
|
|
|
|
import pandas as pd
|
|
|
|
import numpy as np
|
|
|
|
import json
|
|
|
|
from logzero import logger
|
2022-12-12 09:23:45 +08:00
|
|
|
from flask import stream_with_context
|
2022-12-07 10:43:52 +08:00
|
|
|
|
2022-12-08 17:39:29 +08:00
|
|
|
import io
|
2022-12-12 09:23:45 +08:00
|
|
|
import csv
|
2022-12-09 18:43:09 +08:00
|
|
|
from house_price.house_price_predcition import run_boston_price
|
|
|
|
from ocean_wave.wave_height_mlp import predict_wave_height
|
2022-12-07 10:43:52 +08:00
|
|
|
from prophet_predict.prophet_predict import run_prophet
|
|
|
|
|
2022-12-08 16:22:29 +08:00
|
|
|
# Value written to the "dtype" field of JSON replies (used by /ocean_wave_height/).
TEXT: str = "text"

# The Flask application; every route below is registered on it.
app = Flask(__name__)
|
|
|
|
|
2022-12-12 09:23:45 +08:00
|
|
|
|
2022-12-12 09:51:49 +08:00
|
|
|
def generate(data: pd.DataFrame):
    """Yield ``data`` as CSV text, one line per chunk, for streaming responses.

    The first chunk is the header row (the column labels). Each following
    chunk is one data row. The index is not written. Lines use the default
    ``csv.writer`` dialect (comma separated, CRLF terminated, quoting as
    needed).

    Args:
        data: the frame to serialise.

    Yields:
        str: one CSV-formatted line at a time.
    """
    # A StringIO buffer is reused as an in-memory scratch line: write one row,
    # hand its text to the caller, then clear it. No file is created on disk.
    buf = io.StringIO()
    writer = csv.writer(buf)

    def _render(row):
        # Format a single row (a sequence of fields) and reset the buffer.
        # Passing a bare string would be split per character, so callers
        # always pass a list/tuple.
        writer.writerow(row)
        line = buf.getvalue()
        buf.seek(0)
        buf.truncate(0)
        return line

    yield _render(data.columns.tolist())

    # itertuples keeps each column's own dtype. In contrast, data.iloc[i]
    # builds one Series per row and upcasts e.g. an int column to float in a
    # mixed int/float frame, which would write "1.0" instead of "1".
    # itertuples is also far cheaper than positional indexing per row.
    # name=None yields plain tuples, so arbitrary column labels are fine.
    for row in data.itertuples(index=False, name=None):
        yield _render(row)
|
|
|
|
|
|
|
|
|
2022-12-09 18:27:43 +08:00
|
|
|
@app.route('/house_price/', methods=["POST"])
def predict_price():
    """Predict house prices for an uploaded CSV and stream the result back.

    Form fields (optional, numeric strings):
        eta (float, default 0.05), max_depth (int, default 10),
        subsample (float, default 0.7), cosample_bytree (float, default 0.8)
            -- forwarded to ``run_boston_price`` as keyword arguments.
        num_boost_round (int, default 1000),
        early_stopping_rounds (int, default 200)
            -- forwarded positionally.
    Files:
        test_data -- required CSV with the rows to predict.
        train_data -- optional training CSV. ``None`` is forwarded when it is
            missing or the file field was left empty.

    Returns:
        On success, ``house_price.csv`` streamed as a CSV attachment built
        from the DataFrame returned by ``run_boston_price``. On any failure
        (bad numeric field, unreadable CSV, empty test data, model error), a
        JSON body ``{"msg": ..., "code": 406}``. The HTTP status is always
        200; clients are expected to check ``code``.
    """
    resp_info = dict()
    rst = None

    train_file = request.files.get('train_data', None)
    test_file = request.files.get('test_data', None)
    logger.info(train_file)

    # Everything that depends on client input lives inside the try, so a bad
    # number or a malformed CSV becomes a 406 JSON reply instead of an
    # unhandled exception (HTTP 500).
    try:
        params = {
            "eta": float(request.form.get('eta', 0.05)),
            "max_depth": int(request.form.get('max_depth', 10)),
            "subsample": float(request.form.get('subsample', 0.7)),
            # NOTE(review): xgboost names this "colsample_bytree"; the key is
            # kept unchanged because run_boston_price may rely on it -- confirm.
            "cosample_bytree": float(request.form.get('cosample_bytree', 0.8)),
        }
        num_boost_round = int(request.form.get('num_boost_round', 1000))
        early_stopping_rounds = int(request.form.get('early_stopping_rounds', 200))

        # A file field submitted without choosing a file yields a falsy
        # FileStorage (empty filename). Truthiness therefore covers both
        # "missing" and "left empty" for both uploads.
        train_data = pd.read_csv(train_file) if train_file else None

        if not test_file:
            resp_info["msg"] = "测试数据为空"
            resp_info["code"] = 406
        else:
            test_data = pd.read_csv(test_file)
            if test_data.shape[0] == 0:
                resp_info["msg"] = "测试数据为空"
                resp_info["code"] = 406
            else:
                # train_data is None when no training file was supplied.
                rst = run_boston_price(test_data, train_data, num_boost_round,
                                       early_stopping_rounds, **params)
                resp_info["code"] = 200
    except Exception as e:
        # HTTP boundary: report any failure to the client as a 406 body.
        logger.error(f"Error: {e}")
        resp_info["msg"] = str(e)
        resp_info["code"] = 406

    if resp_info["code"] == 200:
        resp = make_response(stream_with_context(generate(rst)))
        resp.headers["Content-Disposition"] = "attachment; filename=house_price.csv"
        resp.headers["Content-type"] = "text/csv;charset=utf-8"
    else:
        resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp
|
|
|
|
|
|
|
|
|
2022-12-09 18:27:43 +08:00
|
|
|
@app.route('/ocean_wave_height/', methods=["POST"])
def predict_height():
    """Predict the wave height of one observation, optionally retraining first.

    Form fields:
        num_units (int, default 8), activation (default 'relu'),
        learning_rate (float, default 0.01), loss (default 'mae'),
        num_boost_round (int, default 100)
            -- hyper-parameters forwarded to ``predict_wave_height``. The value
               of ``num_boost_round`` is used as the epoch count (field name
               kept for client compatibility).
        WVHT_1, WDIR_1, WSPD_1, WDIR_2, WSPD_2, WDIR, WSPD
            -- required numeric features of the sample to predict. They are
               packed into a (1, 7) array in this order.
    Files:
        train_data -- optional training CSV. ``None`` is forwarded when absent.

    Returns:
        A JSON body whose ``dtype`` is ``TEXT``. On success it is
        ``{"code": 200, "data": str(prediction), "dtype": ...}``. On failure
        it is ``{"msg": ..., "code": 406, "dtype": ...}``. The HTTP status is
        always 200; clients are expected to check ``code``.
    """
    feature_names = ("WVHT_1", "WDIR_1", "WSPD_1", "WDIR_2", "WSPD_2", "WDIR", "WSPD")

    def _reply(info):
        # Serialise ``info`` as the JSON reply; errors are signalled in the body.
        info["dtype"] = TEXT
        resp = make_response(json.dumps(info))
        resp.status_code = 200
        return resp

    # 1. Hyper-parameters and the feature vector. Previously a missing feature
    #    reached float(None) and a non-numeric one reached float('abc'); both
    #    escaped as HTTP 500. They are now reported as 406.
    try:
        num_units = int(request.form.get('num_units', 8))
        activation = request.form.get('activation', 'relu')
        lr = float(request.form.get('learning_rate', 0.01))
        loss = request.form.get('loss', 'mae')
        epochs = int(request.form.get('num_boost_round', 100))
        missing = [name for name in feature_names if request.form.get(name) is None]
        if missing:
            raise ValueError(f"缺少参数: {', '.join(missing)}")
        x_test = np.array([[float(request.form[name]) for name in feature_names]])
    except (TypeError, ValueError) as e:
        logger.error(f"Error: {e}")
        return _reply({"msg": str(e), "code": 406})
    logger.info(f"test data: {x_test}")

    # 2. Optional training data. A file that cannot be parsed is reported to
    #    the client instead of being silently replaced by the default model
    #    (the old code set code=406 here and then overwrote it with 200).
    train_file = request.files.get('train_data', None)
    train_data = None
    if train_file:
        try:
            train_data = pd.read_csv(train_file)
        except Exception as e:
            logger.error(f"Error: {e}")
            return _reply({"msg": str(e), "code": 406})

    # 3. Prediction.
    try:
        rst = predict_wave_height(train_data, num_units, activation, lr, loss, epochs, x_test)
    except Exception as e:
        logger.error(f"Error: {e}")
        return _reply({"msg": "上传数据不符合海浪高度预测的规定文件示例,请检查", "code": 406})
    return _reply({"code": 200, "data": str(rst)})
|
2022-12-09 18:43:09 +08:00
|
|
|
|
2022-12-07 10:43:52 +08:00
|
|
|
|
|
|
|
@app.route("/prophet/", methods=["POST"])
def run_ts_predict():
    """Forecast an uploaded time series with Prophet and stream it back as CSV.

    Request fields:
        data (file) -- CSV read with ``pd.read_csv``.
        freq (form) -- passed unchanged to ``run_prophet``.
        period (form) -- converted with ``int`` and passed to ``run_prophet``.

    Returns:
        On success, ``result.csv`` streamed as a CSV attachment. The ``ds``
        and ``yhat`` columns of the forecast are converted to strings first.
        If any step raises, the reply is a JSON body
        ``{"msg": str(error), "code": 406}``. The HTTP status is always 200.
    """
    result_name = "result.csv"
    info = {}

    if request.method == "POST":
        upload = request.files.get("data")
        freq = request.form.get('freq')
        period = request.form.get('period')
        try:
            frame = pd.read_csv(upload)
            logger.info(frame.shape)
            forecast = run_prophet(frame, period=int(period), freq=freq)
            logger.info(forecast.columns)
            # Render timestamps and predictions as plain text for the CSV.
            for column in ("ds", "yhat"):
                forecast[column] = forecast[column].apply(str)
        except Exception as err:
            logger.error(f"Error: {err}")
            info["msg"] = str(err)
            info["code"] = 406
        else:
            info["code"] = 200

    if info.get("code") != 200:
        response = make_response(json.dumps(info))
    else:
        response = make_response(stream_with_context(generate(forecast)))
        response.headers["Content-Disposition"] = f"attachment; filename={result_name}"
        response.headers["Content-type"] = "text/csv;charset=utf-8"
    response.status_code = 200
    return response
|
2022-12-07 10:43:52 +08:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Listen on every network interface, port 8901, with debug mode off.
    app.run(
        host='0.0.0.0',
        port=8901,
        debug=False,
    )
|