From 08013b7628f72140e803a760c78cb6f5cb201422 Mon Sep 17 00:00:00 2001 From: zhaojinghao Date: Fri, 9 Dec 2022 10:27:54 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9response?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- run.py | 191 ++++++++++++++++++++++++++++++--------------------------- 1 file changed, 100 insertions(+), 91 deletions(-) diff --git a/run.py b/run.py index 334c06d..de3b944 100644 --- a/run.py +++ b/run.py @@ -6,103 +6,112 @@ import json from logzero import logger import io -# from house_price.house_price_predcition import run_boston_price -# from ocean_wave.wave_height_mlp import predict_wave_height +from house_price.house_price_predcition import run_boston_price +from ocean_wave.wave_height_mlp import predict_wave_height from prophet_predict.prophet_predict import run_prophet TEXT = "text" app = Flask(__name__) -# @app.route('/house_price', methods=["POST"]) -# def predict_price(): -# resp_info = dict() -# if request.method == 'POST': -# eta = request.form.get('eta', 0.05) -# max_depth = request.form.get('max_depth', 10) -# subsample = request.form.get('subsample', 0.7) -# cosample_bytree = request.form.get('cosample_bytree', 0.8) -# num_boost_round = int(request.form.get('num_boost_round', 1000)) -# early_stopping_rounds = int(request.form.get('early_stopping_rounds', 200)) -# train_data = request.files.get('train_data', None) -# test_data = request.files.get('test_data', None) -# logger.info(train_data) -# params = { -# "eta": float(eta), -# "max_depth": int(max_depth), -# "subsample": float(subsample), -# "cosample_bytree": float(cosample_bytree) -# } -# if not train_data: -# train_data = None -# else: -# train_data = pd.read_csv(train_data) -# if test_data is None or pd.read_csv(test_data).shape[0] == 0: -# resp_info["msg"] = "测试数据为空" -# resp_info["code"] = 406 -# else: -# test_data = pd.read_csv(test_data) -# try: -# if train_data is None: -# rst = run_boston_price(test_data, None, num_boost_round, early_stopping_rounds, **params) -# else: -# rst = run_boston_price(test_data, train_data, num_boost_round, early_stopping_rounds, **params) -# except Exception as e: -# logger.error(f"Error: {e}") -# resp_info["msg"] = str(e) -# resp_info["code"] = 406 -# else: -# resp_info["code"] = 200 -# resp_info["data"] = rst.to_csv() -# resp_info["dtype"] = "csv" -# resp = make_response(json.dumps(resp_info)) -# resp.status_code = 200 -# return resp -# -# -# @app.route('/ocean_wave_height', methods=["POST"]) -# def predict_height(): -# resp_info = dict() -# if request.method == 'POST': -# num_units = int(request.form.get('num_units', 8)) -# activation = request.form.get('activation', 'relu') -# lr = float(request.form.get('learning_rate', 0.01)) -# loss = request.form.get('loss', 'mae') -# epochs = int(request.form.get('num_boost_round', 100)) -# train_data = request.files.get('train_data', None) -# WVHT_1 = float(request.form.get("WVHT_1", None)) -# WDIR_1 = float(request.form.get("WDIR_1", None)) -# WSPD_1 = float(request.form.get("WSPD_1", None)) -# WDIR_2 = float(request.form.get("WDIR_2", None)) -# WSPD_2 = float(request.form.get("WSPD_2", None)) -# WDIR = float(request.form.get("WDIR", None)) -# WSPD = float(request.form.get("WSPD", None)) -# x_test = [WVHT_1, WDIR_1, WSPD_1, WDIR_2, WSPD_2, WDIR, WSPD] -# x_test = np.array([x_test]) -# logger.info(f"test data: {x_test}") -# if not train_data: -# train_data = None -# else: -# try: -# train_data = pd.read_csv(train_data) -# except Exception as e: -# logger.error(f"Error: {e}") -# resp_info["msg"] = str(e) -# resp_info["code"] = 406 -# train_data = None -# try: -# rst = predict_wave_height(train_data, num_units, activation, lr, loss, epochs, x_test) -# except Exception as e: -# logger.error(f"Error: {e}") -# resp_info["msg"] = "上传数据不符合海浪高度预测的规定文件示例,请检查" -# resp_info["code"] = 406 -# else: -# resp_info["code"] = 200 -# resp_info["data"] = rst -# resp_info["dtype"] = TEXT -# resp = make_response(json.dumps(resp_info)) -# resp.status_code = 200 -# return resp +@app.route('/house_price', methods=["POST"]) +def predict_price(): + resp_info = dict() + if request.method == 'POST': + eta = request.form.get('eta', 0.05) + max_depth = request.form.get('max_depth', 10) + subsample = request.form.get('subsample', 0.7) + cosample_bytree = request.form.get('cosample_bytree', 0.8) + num_boost_round = int(request.form.get('num_boost_round', 1000)) + early_stopping_rounds = int(request.form.get('early_stopping_rounds', 200)) + train_data = request.files.get('train_data', None) + test_data = request.files.get('test_data', None) + logger.info(train_data) + params = { + "eta": float(eta), + "max_depth": int(max_depth), + "subsample": float(subsample), + "cosample_bytree": float(cosample_bytree) + } + if not train_data: + train_data = None + else: + train_data = pd.read_csv(train_data) + if test_data is None or pd.read_csv(test_data).shape[0] == 0: + resp_info["msg"] = "测试数据为空" + resp_info["code"] = 406 + else: + test_data = pd.read_csv(test_data) + try: + if train_data is None: + rst = run_boston_price(test_data, None, num_boost_round, early_stopping_rounds, **params) + else: + rst = run_boston_price(test_data, train_data, num_boost_round, early_stopping_rounds, **params) + except Exception as e: + logger.error(f"Error: {e}") + resp_info["msg"] = str(e) + resp_info["code"] = 406 + else: + out = io.BytesIO() + writer = pd.ExcelWriter(out, engine='xlsxwriter') + rst.to_excel(excel_writer=writer, sheet_name='Sheet1', index=False) + writer.save() + writer.close() + resp_info["code"] = 200 + resp_info["data"] = out.getvalue() + if resp_info["code"] == 200: + resp = make_response(resp_info["data"]) + resp.headers["Content-Disposition"] = "attachment; filename*=utf-8''{}".format("house_price.xlsx") + resp.headers["Content-type"] = "application/x-xlsx" + else: + resp = make_response(resp_info) + resp.status_code = 200 + return resp + + +@app.route('/ocean_wave_height', methods=["POST"]) +def predict_height(): + resp_info = dict() + if request.method == 'POST': + num_units = int(request.form.get('num_units', 8)) + activation = request.form.get('activation', 'relu') + lr = float(request.form.get('learning_rate', 0.01)) + loss = request.form.get('loss', 'mae') + epochs = int(request.form.get('num_boost_round', 100)) + train_data = request.files.get('train_data', None) + WVHT_1 = float(request.form.get("WVHT_1", None)) + WDIR_1 = float(request.form.get("WDIR_1", None)) + WSPD_1 = float(request.form.get("WSPD_1", None)) + WDIR_2 = float(request.form.get("WDIR_2", None)) + WSPD_2 = float(request.form.get("WSPD_2", None)) + WDIR = float(request.form.get("WDIR", None)) + WSPD = float(request.form.get("WSPD", None)) + x_test = [WVHT_1, WDIR_1, WSPD_1, WDIR_2, WSPD_2, WDIR, WSPD] + x_test = np.array([x_test]) + logger.info(f"test data: {x_test}") + if not train_data: + train_data = None + else: + try: + train_data = pd.read_csv(train_data) + except Exception as e: + logger.error(f"Error: {e}") + resp_info["msg"] = str(e) + resp_info["code"] = 406 + train_data = None + try: + rst = predict_wave_height(train_data, num_units, activation, lr, loss, epochs, x_test) + except Exception as e: + logger.error(f"Error: {e}") + resp_info["msg"] = "上传数据不符合海浪高度预测的规定文件示例,请检查" + resp_info["code"] = 406 + else: + resp_info["code"] = 200 + resp_info["data"] = rst + resp_info["dtype"] = TEXT + resp = make_response(json.dumps(resp_info)) + resp.status_code = 200 + return resp @app.route("/prophet/", methods=["POST"])