修改response
This commit is contained in:
parent
df82641305
commit
5c75e817b0
200
run.py
200
run.py
|
@ -5,107 +5,110 @@ import numpy as np
|
||||||
import json
|
import json
|
||||||
from logzero import logger
|
from logzero import logger
|
||||||
|
|
||||||
from house_price.house_price_predcition import run_boston_price
|
import io
|
||||||
from ocean_wave.wave_height_mlp import predict_wave_height
|
# from house_price.house_price_predcition import run_boston_price
|
||||||
|
# from ocean_wave.wave_height_mlp import predict_wave_height
|
||||||
from prophet_predict.prophet_predict import run_prophet
|
from prophet_predict.prophet_predict import run_prophet
|
||||||
|
|
||||||
TEXT = "text"
|
TEXT = "text"
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
@app.route('/house_price', methods=["POST"])
|
|
||||||
def predict_price():
|
|
||||||
resp_info = dict()
|
|
||||||
if request.method == 'POST':
|
|
||||||
eta = request.form.get('eta', 0.05)
|
|
||||||
max_depth = request.form.get('max_depth', 10)
|
|
||||||
subsample = request.form.get('subsample', 0.7)
|
|
||||||
cosample_bytree = request.form.get('cosample_bytree', 0.8)
|
|
||||||
num_boost_round = int(request.form.get('num_boost_round', 1000))
|
|
||||||
early_stopping_rounds = int(request.form.get('early_stopping_rounds', 200))
|
|
||||||
train_data = request.files.get('train_data', None)
|
|
||||||
test_data = request.files.get('test_data', None)
|
|
||||||
logger.info(train_data)
|
|
||||||
params = {
|
|
||||||
"eta": float(eta),
|
|
||||||
"max_depth": int(max_depth),
|
|
||||||
"subsample": float(subsample),
|
|
||||||
"cosample_bytree": float(cosample_bytree)
|
|
||||||
}
|
|
||||||
if not train_data:
|
|
||||||
train_data = None
|
|
||||||
else:
|
|
||||||
train_data = pd.read_csv(train_data)
|
|
||||||
if test_data is None or pd.read_csv(test_data).shape[0] == 0:
|
|
||||||
resp_info["msg"] = "测试数据为空"
|
|
||||||
resp_info["code"] = 406
|
|
||||||
else:
|
|
||||||
test_data = pd.read_csv(test_data)
|
|
||||||
try:
|
|
||||||
if train_data is None:
|
|
||||||
rst = run_boston_price(test_data, None, num_boost_round, early_stopping_rounds, **params)
|
|
||||||
else:
|
|
||||||
rst = run_boston_price(test_data, train_data, num_boost_round, early_stopping_rounds, **params)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error: {e}")
|
|
||||||
resp_info["msg"] = str(e)
|
|
||||||
resp_info["code"] = 406
|
|
||||||
else:
|
|
||||||
resp_info["code"] = 200
|
|
||||||
resp_info["data"] = rst.to_csv()
|
|
||||||
resp_info["dtype"] = "csv"
|
|
||||||
resp = make_response(json.dumps(resp_info))
|
|
||||||
resp.status_code = 200
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
# @app.route('/house_price', methods=["POST"])
|
||||||
@app.route('/ocean_wave_height', methods=["POST"])
|
# def predict_price():
|
||||||
def predict_height():
|
# resp_info = dict()
|
||||||
resp_info = dict()
|
# if request.method == 'POST':
|
||||||
if request.method == 'POST':
|
# eta = request.form.get('eta', 0.05)
|
||||||
num_units = int(request.form.get('num_units', 8))
|
# max_depth = request.form.get('max_depth', 10)
|
||||||
activation = request.form.get('activation', 'relu')
|
# subsample = request.form.get('subsample', 0.7)
|
||||||
lr = float(request.form.get('learning_rate', 0.01))
|
# cosample_bytree = request.form.get('cosample_bytree', 0.8)
|
||||||
loss = request.form.get('loss', 'mae')
|
# num_boost_round = int(request.form.get('num_boost_round', 1000))
|
||||||
epochs = int(request.form.get('num_boost_round', 100))
|
# early_stopping_rounds = int(request.form.get('early_stopping_rounds', 200))
|
||||||
train_data = request.files.get('train_data', None)
|
# train_data = request.files.get('train_data', None)
|
||||||
WVHT_1 = float(request.form.get("WVHT_1", None))
|
# test_data = request.files.get('test_data', None)
|
||||||
WDIR_1 = float(request.form.get("WDIR_1", None))
|
# logger.info(train_data)
|
||||||
WSPD_1 = float(request.form.get("WSPD_1", None))
|
# params = {
|
||||||
WDIR_2 = float(request.form.get("WDIR_2", None))
|
# "eta": float(eta),
|
||||||
WSPD_2 = float(request.form.get("WSPD_2", None))
|
# "max_depth": int(max_depth),
|
||||||
WDIR = float(request.form.get("WDIR", None))
|
# "subsample": float(subsample),
|
||||||
WSPD = float(request.form.get("WSPD", None))
|
# "cosample_bytree": float(cosample_bytree)
|
||||||
x_test = [WVHT_1, WDIR_1, WSPD_1, WDIR_2, WSPD_2, WDIR, WSPD]
|
# }
|
||||||
x_test = np.array([x_test])
|
# if not train_data:
|
||||||
logger.info(f"test data: {x_test}")
|
# train_data = None
|
||||||
if not train_data:
|
# else:
|
||||||
train_data = None
|
# train_data = pd.read_csv(train_data)
|
||||||
else:
|
# if test_data is None or pd.read_csv(test_data).shape[0] == 0:
|
||||||
try:
|
# resp_info["msg"] = "测试数据为空"
|
||||||
train_data = pd.read_csv(train_data)
|
# resp_info["code"] = 406
|
||||||
except Exception as e:
|
# else:
|
||||||
logger.error(f"Error: {e}")
|
# test_data = pd.read_csv(test_data)
|
||||||
resp_info["msg"] = str(e)
|
# try:
|
||||||
resp_info["code"] = 406
|
# if train_data is None:
|
||||||
train_data = None
|
# rst = run_boston_price(test_data, None, num_boost_round, early_stopping_rounds, **params)
|
||||||
try:
|
# else:
|
||||||
rst = predict_wave_height(train_data, num_units, activation, lr, loss, epochs, x_test)
|
# rst = run_boston_price(test_data, train_data, num_boost_round, early_stopping_rounds, **params)
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"Error: {e}")
|
# logger.error(f"Error: {e}")
|
||||||
resp_info["msg"] = "上传数据不符合海浪高度预测的规定文件示例,请检查"
|
# resp_info["msg"] = str(e)
|
||||||
resp_info["code"] = 406
|
# resp_info["code"] = 406
|
||||||
else:
|
# else:
|
||||||
resp_info["code"] = 200
|
# resp_info["code"] = 200
|
||||||
resp_info["data"] = rst
|
# resp_info["data"] = rst.to_csv()
|
||||||
resp_info["dtype"] = TEXT
|
# resp_info["dtype"] = "csv"
|
||||||
resp = make_response(json.dumps(resp_info))
|
# resp = make_response(json.dumps(resp_info))
|
||||||
resp.status_code = 200
|
# resp.status_code = 200
|
||||||
return resp
|
# return resp
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# @app.route('/ocean_wave_height', methods=["POST"])
|
||||||
|
# def predict_height():
|
||||||
|
# resp_info = dict()
|
||||||
|
# if request.method == 'POST':
|
||||||
|
# num_units = int(request.form.get('num_units', 8))
|
||||||
|
# activation = request.form.get('activation', 'relu')
|
||||||
|
# lr = float(request.form.get('learning_rate', 0.01))
|
||||||
|
# loss = request.form.get('loss', 'mae')
|
||||||
|
# epochs = int(request.form.get('num_boost_round', 100))
|
||||||
|
# train_data = request.files.get('train_data', None)
|
||||||
|
# WVHT_1 = float(request.form.get("WVHT_1", None))
|
||||||
|
# WDIR_1 = float(request.form.get("WDIR_1", None))
|
||||||
|
# WSPD_1 = float(request.form.get("WSPD_1", None))
|
||||||
|
# WDIR_2 = float(request.form.get("WDIR_2", None))
|
||||||
|
# WSPD_2 = float(request.form.get("WSPD_2", None))
|
||||||
|
# WDIR = float(request.form.get("WDIR", None))
|
||||||
|
# WSPD = float(request.form.get("WSPD", None))
|
||||||
|
# x_test = [WVHT_1, WDIR_1, WSPD_1, WDIR_2, WSPD_2, WDIR, WSPD]
|
||||||
|
# x_test = np.array([x_test])
|
||||||
|
# logger.info(f"test data: {x_test}")
|
||||||
|
# if not train_data:
|
||||||
|
# train_data = None
|
||||||
|
# else:
|
||||||
|
# try:
|
||||||
|
# train_data = pd.read_csv(train_data)
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"Error: {e}")
|
||||||
|
# resp_info["msg"] = str(e)
|
||||||
|
# resp_info["code"] = 406
|
||||||
|
# train_data = None
|
||||||
|
# try:
|
||||||
|
# rst = predict_wave_height(train_data, num_units, activation, lr, loss, epochs, x_test)
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"Error: {e}")
|
||||||
|
# resp_info["msg"] = "上传数据不符合海浪高度预测的规定文件示例,请检查"
|
||||||
|
# resp_info["code"] = 406
|
||||||
|
# else:
|
||||||
|
# resp_info["code"] = 200
|
||||||
|
# resp_info["data"] = rst
|
||||||
|
# resp_info["dtype"] = TEXT
|
||||||
|
# resp = make_response(json.dumps(resp_info))
|
||||||
|
# resp.status_code = 200
|
||||||
|
# return resp
|
||||||
|
|
||||||
|
|
||||||
@app.route("/prophet/", methods=["POST"])
|
@app.route("/prophet/", methods=["POST"])
|
||||||
def run_ts_predict():
|
def run_ts_predict():
|
||||||
resp_info = dict()
|
resp_info = dict()
|
||||||
|
file_name = "rest.xlsx"
|
||||||
if request.method == "POST":
|
if request.method == "POST":
|
||||||
data_file = request.files.get("data")
|
data_file = request.files.get("data")
|
||||||
freq = request.form.get('freq')
|
freq = request.form.get('freq')
|
||||||
|
@ -125,13 +128,22 @@ def run_ts_predict():
|
||||||
resp_info["msg"] = str(e)
|
resp_info["msg"] = str(e)
|
||||||
resp_info["code"] = 406
|
resp_info["code"] = 406
|
||||||
else:
|
else:
|
||||||
|
out = io.BytesIO()
|
||||||
|
writer = pd.ExcelWriter(out, engine='xlsxwriter')
|
||||||
|
rest.to_excel(excel_writer=writer, sheet_name='Sheet1', index=False)
|
||||||
|
writer.save()
|
||||||
|
writer.close()
|
||||||
resp_info["code"] = 200
|
resp_info["code"] = 200
|
||||||
resp_info["data"] = rest.to_csv()
|
resp_info["data"] = out.getvalue()
|
||||||
resp_info["dtype"] = "csv"
|
if resp_info.get("code") == 200:
|
||||||
resp = make_response(json.dumps(resp_info))
|
resp = make_response(resp_info["data"])
|
||||||
|
resp.headers["Content-Disposition"] = "attachment; filename*=utf-8''{}".format(file_name)
|
||||||
|
resp.headers["Content-type"] = "application/x-xlsx"
|
||||||
|
else:
|
||||||
|
resp = make_response(resp_info)
|
||||||
resp.status_code = 200
|
resp.status_code = 200
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
app.run(host='0.0.0.0', port=8901, debug=False)
|
app.run(host='0.0.0.0', port=8901, debug=True)
|
||||||
|
|
Loading…
Reference in New Issue