修改response

This commit is contained in:
zhaojinghao 2022-12-09 10:27:54 +08:00
parent 7dddaf1dea
commit 08013b7628
1 changed files with 100 additions and 91 deletions

191
run.py
View File

@ -6,103 +6,112 @@ import json
from logzero import logger from logzero import logger
import io import io
# from house_price.house_price_predcition import run_boston_price from house_price.house_price_predcition import run_boston_price
# from ocean_wave.wave_height_mlp import predict_wave_height from ocean_wave.wave_height_mlp import predict_wave_height
from prophet_predict.prophet_predict import run_prophet from prophet_predict.prophet_predict import run_prophet
TEXT = "text" TEXT = "text"
app = Flask(__name__) app = Flask(__name__)
# @app.route('/house_price', methods=["POST"]) @app.route('/house_price', methods=["POST"])
# def predict_price(): def predict_price():
# resp_info = dict() resp_info = dict()
# if request.method == 'POST': if request.method == 'POST':
# eta = request.form.get('eta', 0.05) eta = request.form.get('eta', 0.05)
# max_depth = request.form.get('max_depth', 10) max_depth = request.form.get('max_depth', 10)
# subsample = request.form.get('subsample', 0.7) subsample = request.form.get('subsample', 0.7)
# cosample_bytree = request.form.get('cosample_bytree', 0.8) cosample_bytree = request.form.get('cosample_bytree', 0.8)
# num_boost_round = int(request.form.get('num_boost_round', 1000)) num_boost_round = int(request.form.get('num_boost_round', 1000))
# early_stopping_rounds = int(request.form.get('early_stopping_rounds', 200)) early_stopping_rounds = int(request.form.get('early_stopping_rounds', 200))
# train_data = request.files.get('train_data', None) train_data = request.files.get('train_data', None)
# test_data = request.files.get('test_data', None) test_data = request.files.get('test_data', None)
# logger.info(train_data) logger.info(train_data)
# params = { params = {
# "eta": float(eta), "eta": float(eta),
# "max_depth": int(max_depth), "max_depth": int(max_depth),
# "subsample": float(subsample), "subsample": float(subsample),
# "cosample_bytree": float(cosample_bytree) "cosample_bytree": float(cosample_bytree)
# } }
# if not train_data: if not train_data:
# train_data = None train_data = None
# else: else:
# train_data = pd.read_csv(train_data) train_data = pd.read_csv(train_data)
# if test_data is None or pd.read_csv(test_data).shape[0] == 0: if test_data is None or pd.read_csv(test_data).shape[0] == 0:
# resp_info["msg"] = "测试数据为空" resp_info["msg"] = "测试数据为空"
# resp_info["code"] = 406 resp_info["code"] = 406
# else: else:
# test_data = pd.read_csv(test_data) test_data = pd.read_csv(test_data)
# try: try:
# if train_data is None: if train_data is None:
# rst = run_boston_price(test_data, None, num_boost_round, early_stopping_rounds, **params) rst = run_boston_price(test_data, None, num_boost_round, early_stopping_rounds, **params)
# else: else:
# rst = run_boston_price(test_data, train_data, num_boost_round, early_stopping_rounds, **params) rst = run_boston_price(test_data, train_data, num_boost_round, early_stopping_rounds, **params)
# except Exception as e: except Exception as e:
# logger.error(f"Error: {e}") logger.error(f"Error: {e}")
# resp_info["msg"] = str(e) resp_info["msg"] = str(e)
# resp_info["code"] = 406 resp_info["code"] = 406
# else: else:
# resp_info["code"] = 200 out = io.BytesIO()
# resp_info["data"] = rst.to_csv() writer = pd.ExcelWriter(out, engine='xlsxwriter')
# resp_info["dtype"] = "csv" rst.to_excel(excel_writer=writer, sheet_name='Sheet1', index=False)
# resp = make_response(json.dumps(resp_info)) writer.save()
# resp.status_code = 200 writer.close()
# return resp resp_info["code"] = 200
# resp_info["data"] = out.getvalue()
# if resp_info["code"] == 200:
# @app.route('/ocean_wave_height', methods=["POST"]) resp = make_response(resp_info["data"])
# def predict_height(): resp.headers["Content-Disposition"] = "attachment; filename*=utf-8''{}".format("house_price.xlsx")
# resp_info = dict() resp.headers["Content-type"] = "application/x-xlsx"
# if request.method == 'POST': else:
# num_units = int(request.form.get('num_units', 8)) resp = make_response(resp_info)
# activation = request.form.get('activation', 'relu') resp.status_code = 200
# lr = float(request.form.get('learning_rate', 0.01)) return resp
# loss = request.form.get('loss', 'mae')
# epochs = int(request.form.get('num_boost_round', 100))
# train_data = request.files.get('train_data', None) @app.route('/ocean_wave_height', methods=["POST"])
# WVHT_1 = float(request.form.get("WVHT_1", None)) def predict_height():
# WDIR_1 = float(request.form.get("WDIR_1", None)) resp_info = dict()
# WSPD_1 = float(request.form.get("WSPD_1", None)) if request.method == 'POST':
# WDIR_2 = float(request.form.get("WDIR_2", None)) num_units = int(request.form.get('num_units', 8))
# WSPD_2 = float(request.form.get("WSPD_2", None)) activation = request.form.get('activation', 'relu')
# WDIR = float(request.form.get("WDIR", None)) lr = float(request.form.get('learning_rate', 0.01))
# WSPD = float(request.form.get("WSPD", None)) loss = request.form.get('loss', 'mae')
# x_test = [WVHT_1, WDIR_1, WSPD_1, WDIR_2, WSPD_2, WDIR, WSPD] epochs = int(request.form.get('num_boost_round', 100))
# x_test = np.array([x_test]) train_data = request.files.get('train_data', None)
# logger.info(f"test data: {x_test}") WVHT_1 = float(request.form.get("WVHT_1", None))
# if not train_data: WDIR_1 = float(request.form.get("WDIR_1", None))
# train_data = None WSPD_1 = float(request.form.get("WSPD_1", None))
# else: WDIR_2 = float(request.form.get("WDIR_2", None))
# try: WSPD_2 = float(request.form.get("WSPD_2", None))
# train_data = pd.read_csv(train_data) WDIR = float(request.form.get("WDIR", None))
# except Exception as e: WSPD = float(request.form.get("WSPD", None))
# logger.error(f"Error: {e}") x_test = [WVHT_1, WDIR_1, WSPD_1, WDIR_2, WSPD_2, WDIR, WSPD]
# resp_info["msg"] = str(e) x_test = np.array([x_test])
# resp_info["code"] = 406 logger.info(f"test data: {x_test}")
# train_data = None if not train_data:
# try: train_data = None
# rst = predict_wave_height(train_data, num_units, activation, lr, loss, epochs, x_test) else:
# except Exception as e: try:
# logger.error(f"Error: {e}") train_data = pd.read_csv(train_data)
# resp_info["msg"] = "上传数据不符合海浪高度预测的规定文件示例,请检查" except Exception as e:
# resp_info["code"] = 406 logger.error(f"Error: {e}")
# else: resp_info["msg"] = str(e)
# resp_info["code"] = 200 resp_info["code"] = 406
# resp_info["data"] = rst train_data = None
# resp_info["dtype"] = TEXT try:
# resp = make_response(json.dumps(resp_info)) rst = predict_wave_height(train_data, num_units, activation, lr, loss, epochs, x_test)
# resp.status_code = 200 except Exception as e:
# return resp logger.error(f"Error: {e}")
resp_info["msg"] = "上传数据不符合海浪高度预测的规定文件示例,请检查"
resp_info["code"] = 406
else:
resp_info["code"] = 200
resp_info["data"] = rst
resp_info["dtype"] = TEXT
resp = make_response(json.dumps(resp_info))
resp.status_code = 200
return resp
@app.route("/prophet/", methods=["POST"]) @app.route("/prophet/", methods=["POST"])