diff --git a/run.py b/run.py index 9cc6caa..334c06d 100644 --- a/run.py +++ b/run.py @@ -5,107 +5,110 @@ import numpy as np import json from logzero import logger -from house_price.house_price_predcition import run_boston_price -from ocean_wave.wave_height_mlp import predict_wave_height +import io +# from house_price.house_price_predcition import run_boston_price +# from ocean_wave.wave_height_mlp import predict_wave_height from prophet_predict.prophet_predict import run_prophet TEXT = "text" app = Flask(__name__) -@app.route('/house_price', methods=["POST"]) -def predict_price(): - resp_info = dict() - if request.method == 'POST': - eta = request.form.get('eta', 0.05) - max_depth = request.form.get('max_depth', 10) - subsample = request.form.get('subsample', 0.7) - cosample_bytree = request.form.get('cosample_bytree', 0.8) - num_boost_round = int(request.form.get('num_boost_round', 1000)) - early_stopping_rounds = int(request.form.get('early_stopping_rounds', 200)) - train_data = request.files.get('train_data', None) - test_data = request.files.get('test_data', None) - logger.info(train_data) - params = { - "eta": float(eta), - "max_depth": int(max_depth), - "subsample": float(subsample), - "cosample_bytree": float(cosample_bytree) - } - if not train_data: - train_data = None - else: - train_data = pd.read_csv(train_data) - if test_data is None or pd.read_csv(test_data).shape[0] == 0: - resp_info["msg"] = "测试数据为空" - resp_info["code"] = 406 - else: - test_data = pd.read_csv(test_data) - try: - if train_data is None: - rst = run_boston_price(test_data, None, num_boost_round, early_stopping_rounds, **params) - else: - rst = run_boston_price(test_data, train_data, num_boost_round, early_stopping_rounds, **params) - except Exception as e: - logger.error(f"Error: {e}") - resp_info["msg"] = str(e) - resp_info["code"] = 406 - else: - resp_info["code"] = 200 - resp_info["data"] = rst.to_csv() - resp_info["dtype"] = "csv" - resp = make_response(json.dumps(resp_info)) - resp.status_code = 200 - return resp - -@app.route('/ocean_wave_height', methods=["POST"]) -def predict_height(): - resp_info = dict() - if request.method == 'POST': - num_units = int(request.form.get('num_units', 8)) - activation = request.form.get('activation', 'relu') - lr = float(request.form.get('learning_rate', 0.01)) - loss = request.form.get('loss', 'mae') - epochs = int(request.form.get('num_boost_round', 100)) - train_data = request.files.get('train_data', None) - WVHT_1 = float(request.form.get("WVHT_1", None)) - WDIR_1 = float(request.form.get("WDIR_1", None)) - WSPD_1 = float(request.form.get("WSPD_1", None)) - WDIR_2 = float(request.form.get("WDIR_2", None)) - WSPD_2 = float(request.form.get("WSPD_2", None)) - WDIR = float(request.form.get("WDIR", None)) - WSPD = float(request.form.get("WSPD", None)) - x_test = [WVHT_1, WDIR_1, WSPD_1, WDIR_2, WSPD_2, WDIR, WSPD] - x_test = np.array([x_test]) - logger.info(f"test data: {x_test}") - if not train_data: - train_data = None - else: - try: - train_data = pd.read_csv(train_data) - except Exception as e: - logger.error(f"Error: {e}") - resp_info["msg"] = str(e) - resp_info["code"] = 406 - train_data = None - try: - rst = predict_wave_height(train_data, num_units, activation, lr, loss, epochs, x_test) - except Exception as e: - logger.error(f"Error: {e}") - resp_info["msg"] = "上传数据不符合海浪高度预测的规定文件示例,请检查" - resp_info["code"] = 406 - else: - resp_info["code"] = 200 - resp_info["data"] = rst - resp_info["dtype"] = TEXT - resp = make_response(json.dumps(resp_info)) - resp.status_code = 200 - return resp +# @app.route('/house_price', methods=["POST"]) +# def predict_price(): +# resp_info = dict() +# if request.method == 'POST': +# eta = request.form.get('eta', 0.05) +# max_depth = request.form.get('max_depth', 10) +# subsample = request.form.get('subsample', 0.7) +# cosample_bytree = request.form.get('cosample_bytree', 0.8) +# num_boost_round = int(request.form.get('num_boost_round', 1000)) +# early_stopping_rounds = int(request.form.get('early_stopping_rounds', 200)) +# train_data = request.files.get('train_data', None) +# test_data = request.files.get('test_data', None) +# logger.info(train_data) +# params = { +# "eta": float(eta), +# "max_depth": int(max_depth), +# "subsample": float(subsample), +# "cosample_bytree": float(cosample_bytree) +# } +# if not train_data: +# train_data = None +# else: +# train_data = pd.read_csv(train_data) +# if test_data is None or pd.read_csv(test_data).shape[0] == 0: +# resp_info["msg"] = "测试数据为空" +# resp_info["code"] = 406 +# else: +# test_data = pd.read_csv(test_data) +# try: +# if train_data is None: +# rst = run_boston_price(test_data, None, num_boost_round, early_stopping_rounds, **params) +# else: +# rst = run_boston_price(test_data, train_data, num_boost_round, early_stopping_rounds, **params) +# except Exception as e: +# logger.error(f"Error: {e}") +# resp_info["msg"] = str(e) +# resp_info["code"] = 406 +# else: +# resp_info["code"] = 200 +# resp_info["data"] = rst.to_csv() +# resp_info["dtype"] = "csv" +# resp = make_response(json.dumps(resp_info)) +# resp.status_code = 200 +# return resp +# +# +# @app.route('/ocean_wave_height', methods=["POST"]) +# def predict_height(): +# resp_info = dict() +# if request.method == 'POST': +# num_units = int(request.form.get('num_units', 8)) +# activation = request.form.get('activation', 'relu') +# lr = float(request.form.get('learning_rate', 0.01)) +# loss = request.form.get('loss', 'mae') +# epochs = int(request.form.get('num_boost_round', 100)) +# train_data = request.files.get('train_data', None) +# WVHT_1 = float(request.form.get("WVHT_1", None)) +# WDIR_1 = float(request.form.get("WDIR_1", None)) +# WSPD_1 = float(request.form.get("WSPD_1", None)) +# WDIR_2 = float(request.form.get("WDIR_2", None)) +# WSPD_2 = float(request.form.get("WSPD_2", None)) +# WDIR = float(request.form.get("WDIR", None)) +# WSPD = float(request.form.get("WSPD", None)) +# x_test = [WVHT_1, WDIR_1, WSPD_1, WDIR_2, WSPD_2, WDIR, WSPD] +# x_test = np.array([x_test]) +# logger.info(f"test data: {x_test}") +# if not train_data: +# train_data = None +# else: +# try: +# train_data = pd.read_csv(train_data) +# except Exception as e: +# logger.error(f"Error: {e}") +# resp_info["msg"] = str(e) +# resp_info["code"] = 406 +# train_data = None +# try: +# rst = predict_wave_height(train_data, num_units, activation, lr, loss, epochs, x_test) +# except Exception as e: +# logger.error(f"Error: {e}") +# resp_info["msg"] = "上传数据不符合海浪高度预测的规定文件示例,请检查" +# resp_info["code"] = 406 +# else: +# resp_info["code"] = 200 +# resp_info["data"] = rst +# resp_info["dtype"] = TEXT +# resp = make_response(json.dumps(resp_info)) +# resp.status_code = 200 +# return resp @app.route("/prophet/", methods=["POST"]) def run_ts_predict(): resp_info = dict() + file_name = "rest.xlsx" if request.method == "POST": data_file = request.files.get("data") freq = request.form.get('freq') @@ -125,13 +128,22 @@ def run_ts_predict(): resp_info["msg"] = str(e) resp_info["code"] = 406 else: + out = io.BytesIO() + writer = pd.ExcelWriter(out, engine='xlsxwriter') + rest.to_excel(excel_writer=writer, sheet_name='Sheet1', index=False) + writer.save() + writer.close() resp_info["code"] = 200 - resp_info["data"] = rest.to_csv() - resp_info["dtype"] = "csv" - resp = make_response(json.dumps(resp_info)) + resp_info["data"] = out.getvalue() + if resp_info.get("code") == 200: + resp = make_response(resp_info["data"]) + resp.headers["Content-Disposition"] = "attachment; filename*=utf-8''{}".format(file_name) + resp.headers["Content-type"] = "application/x-xlsx" + else: + resp = make_response(resp_info) resp.status_code = 200 return resp if __name__ == '__main__': - app.run(host='0.0.0.0', port=8901, debug=False) + app.run(host='0.0.0.0', port=8901, debug=True)