# -*-coding:utf-8-*-
"""Flask service exposing a Prophet time-series forecast endpoint.

Only ``POST /prophet/`` is active. The house-price and ocean-wave endpoints
are kept below as commented-out code; re-enabling them also requires
uncommenting their imports.
"""
import io
import json
import os

import numpy as np
import pandas as pd
from flask import Flask, request, make_response
from logzero import logger

# from house_price.house_price_predcition import run_boston_price
# from ocean_wave.wave_height_mlp import predict_wave_height
from prophet_predict.prophet_predict import run_prophet

TEXT = "text"

# Registered MIME type for Office Open XML workbooks (.xlsx).
XLSX_MIMETYPE = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"

app = Flask(__name__)


# @app.route('/house_price', methods=["POST"])
# def predict_price():
#     resp_info = dict()
#     if request.method == 'POST':
#         eta = request.form.get('eta', 0.05)
#         max_depth = request.form.get('max_depth', 10)
#         subsample = request.form.get('subsample', 0.7)
#         cosample_bytree = request.form.get('cosample_bytree', 0.8)
#         num_boost_round = int(request.form.get('num_boost_round', 1000))
#         early_stopping_rounds = int(request.form.get('early_stopping_rounds', 200))
#         train_data = request.files.get('train_data', None)
#         test_data = request.files.get('test_data', None)
#         logger.info(train_data)
#         params = {
#             "eta": float(eta),
#             "max_depth": int(max_depth),
#             "subsample": float(subsample),
#             "cosample_bytree": float(cosample_bytree)
#         }
#         if not train_data:
#             train_data = None
#         else:
#             train_data = pd.read_csv(train_data)
#         if test_data is None or pd.read_csv(test_data).shape[0] == 0:
#             resp_info["msg"] = "测试数据为空"
#             resp_info["code"] = 406
#         else:
#             test_data = pd.read_csv(test_data)
#             try:
#                 if train_data is None:
#                     rst = run_boston_price(test_data, None, num_boost_round, early_stopping_rounds, **params)
#                 else:
#                     rst = run_boston_price(test_data, train_data, num_boost_round, early_stopping_rounds, **params)
#             except Exception as e:
#                 logger.error(f"Error: {e}")
#                 resp_info["msg"] = str(e)
#                 resp_info["code"] = 406
#             else:
#                 resp_info["code"] = 200
#                 resp_info["data"] = rst.to_csv()
#                 resp_info["dtype"] = "csv"
#     resp = make_response(json.dumps(resp_info))
#     resp.status_code = 200
#     return resp
#
#
# @app.route('/ocean_wave_height', methods=["POST"])
# def predict_height():
#     resp_info = dict()
#     if request.method == 'POST':
#         num_units = int(request.form.get('num_units', 8))
#         activation = request.form.get('activation', 'relu')
#         lr = float(request.form.get('learning_rate', 0.01))
#         loss = request.form.get('loss', 'mae')
#         epochs = int(request.form.get('num_boost_round', 100))
#         train_data = request.files.get('train_data', None)
#         WVHT_1 = float(request.form.get("WVHT_1", None))
#         WDIR_1 = float(request.form.get("WDIR_1", None))
#         WSPD_1 = float(request.form.get("WSPD_1", None))
#         WDIR_2 = float(request.form.get("WDIR_2", None))
#         WSPD_2 = float(request.form.get("WSPD_2", None))
#         WDIR = float(request.form.get("WDIR", None))
#         WSPD = float(request.form.get("WSPD", None))
#         x_test = [WVHT_1, WDIR_1, WSPD_1, WDIR_2, WSPD_2, WDIR, WSPD]
#         x_test = np.array([x_test])
#         logger.info(f"test data: {x_test}")
#         if not train_data:
#             train_data = None
#         else:
#             try:
#                 train_data = pd.read_csv(train_data)
#             except Exception as e:
#                 logger.error(f"Error: {e}")
#                 resp_info["msg"] = str(e)
#                 resp_info["code"] = 406
#                 train_data = None
#         try:
#             rst = predict_wave_height(train_data, num_units, activation, lr, loss, epochs, x_test)
#         except Exception as e:
#             logger.error(f"Error: {e}")
#             resp_info["msg"] = "上传数据不符合海浪高度预测的规定文件示例,请检查"
#             resp_info["code"] = 406
#         else:
#             resp_info["code"] = 200
#             resp_info["data"] = rst
#             resp_info["dtype"] = TEXT
#     resp = make_response(json.dumps(resp_info))
#     resp.status_code = 200
#     return resp


def _error_response(msg):
    """Build the JSON error body ``{"msg": msg, "code": 406}``.

    The HTTP status is deliberately kept at 200, as this endpoint has always
    returned; clients detect failure through the ``code`` field.
    """
    resp = make_response({"msg": msg, "code": 406})
    resp.status_code = 200
    return resp


def _forecast_to_xlsx(forecast):
    """Serialize the DataFrame *forecast* into in-memory .xlsx bytes (sheet "Sheet1", no index)."""
    out = io.BytesIO()
    # Leaving the context manager closes the writer exactly once, which flushes
    # the workbook into ``out``. ``ExcelWriter.save()`` was deprecated in
    # pandas 1.5 and removed in 2.0, so the old save()+close() pair breaks there.
    with pd.ExcelWriter(out, engine="xlsxwriter") as writer:
        forecast.to_excel(excel_writer=writer, sheet_name="Sheet1", index=False)
    return out.getvalue()


@app.route("/prophet/", methods=["POST"])
def run_ts_predict():
    """Forecast an uploaded time series with ``run_prophet`` and return an .xlsx file.

    Form fields:
        data: uploaded CSV file, read with ``pd.read_csv`` and passed to ``run_prophet``.
        freq: forwarded unchanged to ``run_prophet`` as ``freq``.
        period: must parse as ``int``; forwarded to ``run_prophet`` as ``period``.

    Returns:
        On success, an attachment named ``rest.xlsx`` containing the ``ds`` and
        ``yhat`` columns of the forecast, both converted to strings. On failure,
        a JSON body ``{"msg": <error text>, "code": 406}`` with HTTP status 200.
    """
    file_name = "rest.xlsx"
    data_file = request.files.get("data")
    freq = request.form.get('freq')
    period = request.form.get('period')

    # Validate inputs up front so the client gets an actionable message instead
    # of pandas' "Invalid file path or buffer object type" or int()'s TypeError.
    if data_file is None:
        msg = "missing uploaded file field 'data'"
        logger.error("Error: %s", msg)
        return _error_response(msg)
    try:
        periods = int(period)
    except (TypeError, ValueError):
        msg = f"form field 'period' must be an integer, got {period!r}"
        logger.error("Error: %s", msg)
        return _error_response(msg)

    try:
        data = pd.read_csv(data_file)
        logger.info(data.shape)
        # .copy() detaches the column subset from run_prophet's frame so the
        # column assignments below do not trigger SettingWithCopyWarning.
        rest = run_prophet(data, period=periods, freq=freq)[['ds', 'yhat']].copy()
        logger.info(rest.columns)
        # Keep apply(str) rather than astype(str): for datetime columns the two
        # render differently, and clients may rely on the existing text.
        rest['ds'] = rest['ds'].apply(str)
        rest['yhat'] = rest['yhat'].apply(str)
    except Exception as e:  # request boundary: report to the client, keep serving
        logger.exception("Error: %s", e)
        return _error_response(str(e))

    resp = make_response(_forecast_to_xlsx(rest))
    resp.headers["Content-Disposition"] = "attachment; filename*=utf-8''{}".format(file_name)
    resp.headers["Content-Type"] = XLSX_MIMETYPE
    resp.status_code = 200
    return resp


if __name__ == '__main__':
    # The Werkzeug debugger lets anyone who can reach it run arbitrary Python,
    # so it must not be on by default while binding to 0.0.0.0.
    # Opt in for local development with PROPHET_API_DEBUG=1.
    app.run(host='0.0.0.0', port=8901,
            debug=os.environ.get("PROPHET_API_DEBUG") == "1")