# -*-coding:utf-8-*-
"""Flask service exposing three prediction endpoints.

* ``POST /house_price/``       -> CSV attachment built from ``run_boston_price``
* ``POST /ocean_wave_height/`` -> JSON carrying the ``predict_wave_height`` result
* ``POST /prophet/``           -> CSV attachment built from ``run_prophet``

Every response is sent with HTTP status 200; failures are reported in a JSON
body of the form ``{"msg": <reason>, "code": 406}``.
"""
import csv
import io
import json

import numpy as np
import pandas as pd
from flask import Flask, make_response, request, stream_with_context
from logzero import logger

from house_price.house_price_predcition import run_boston_price
from ocean_wave.wave_height_mlp import predict_wave_height
from prophet_predict.prophet_predict import run_prophet

TEXT = "text"

# Form fields holding the single wave-height test sample. Order matters: the
# values are passed to predict_wave_height as one row in exactly this order.
WAVE_FEATURES = ("WVHT_1", "WDIR_1", "WSPD_1", "WDIR_2", "WSPD_2", "WDIR", "WSPD")

app = Flask(__name__)


def generate(data: pd.DataFrame):
    """Yield *data* as CSV text: the header line first, then one line per row.

    Rows are written into an in-memory ``io.StringIO`` buffer (no file is
    created on disk) that is drained and cleared after every row, so the
    CSV text is streamed rather than built up in full.

    Args:
        data: Frame to serialise; its column labels become the header row.

    Yields:
        str: One CSV-formatted line per chunk.
    """
    out = io.StringIO()
    writer = csv.writer(out)

    def _drain():
        # Return what was written so far, then rewind and truncate the buffer
        # so it is empty for the next row.
        chunk = out.getvalue()
        out.seek(0)
        out.truncate(0)
        return chunk

    writer.writerow(data.columns.tolist())
    yield _drain()
    # itertuples keeps each column's own dtype. data.iloc[i] would build one
    # Series per row and upcast e.g. ints to floats ("1" -> "1.0") whenever
    # the frame mixes int and float columns. Each row is a tuple, never a bare
    # string (csv.writer would split a bare string into single characters).
    for row in data.itertuples(index=False, name=None):
        writer.writerow(row)
        yield _drain()


def _json_response(resp_info):
    """Serialise *resp_info* to a JSON response body with HTTP status 200."""
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp


def _csv_response(data, file_name, content_type):
    """Stream *data* as a CSV attachment named *file_name* (HTTP status 200)."""
    resp = make_response(stream_with_context(generate(data)))
    resp.headers["Content-Disposition"] = f"attachment; filename={file_name}"
    resp.headers["Content-type"] = content_type
    resp.status_code = 200
    return resp


@app.route('/house_price/', methods=["POST"])
def predict_price():
    """Predict house prices for an uploaded test CSV.

    Form fields (all optional, defaults in brackets): ``eta`` [0.05],
    ``max_depth`` [10], ``subsample`` [0.7], ``cosample_bytree`` [0.8],
    ``num_boost_round`` [1000], ``early_stopping_rounds`` [200].
    Files: ``test_data`` (required CSV) and ``train_data`` (optional CSV;
    ``None`` is passed to ``run_boston_price`` when it is absent).

    Returns:
        On success, ``run_boston_price``'s result streamed as
        ``house_price.csv``. Otherwise a JSON body ``{"msg": ..., "code": 406}``.
    """
    # Only POST is routed here, so no request.method check is needed.
    try:
        params = {
            "eta": float(request.form.get('eta', 0.05)),
            "max_depth": int(request.form.get('max_depth', 10)),
            "subsample": float(request.form.get('subsample', 0.7)),
            # NOTE(review): "cosample_bytree" looks like a misspelling of
            # XGBoost's "colsample_bytree". It is kept because the form field
            # and the run_boston_price kwarg may both depend on it - confirm.
            "cosample_bytree": float(request.form.get('cosample_bytree', 0.8)),
        }
        num_boost_round = int(request.form.get('num_boost_round', 1000))
        early_stopping_rounds = int(request.form.get('early_stopping_rounds', 200))
    except ValueError as e:
        # A non-numeric form value is a client error, not a server crash.
        logger.error(f"Error: {e}")
        return _json_response({"msg": str(e), "code": 406})

    train_file = request.files.get('train_data', None)
    test_file = request.files.get('test_data', None)
    logger.info(train_file)

    # A FileStorage is falsy when no file was chosen, so this also catches
    # empty file inputs, not only a missing field.
    if not test_file:
        return _json_response({"msg": "测试数据为空", "code": 406})

    try:
        train_data = pd.read_csv(train_file) if train_file else None
        test_data = pd.read_csv(test_file)
    except Exception as e:  # request boundary: report any parse failure
        logger.error(f"Error: {e}")
        return _json_response({"msg": str(e), "code": 406})

    if test_data.shape[0] == 0:
        return _json_response({"msg": "测试数据为空", "code": 406})

    try:
        rst = run_boston_price(test_data, train_data, num_boost_round,
                               early_stopping_rounds, **params)
    except Exception as e:  # request boundary: report model failures
        logger.error(f"Error: {e}")
        return _json_response({"msg": str(e), "code": 406})

    return _csv_response(rst, "house_price.csv", "text/csv")


@app.route('/ocean_wave_height/', methods=["POST"])
def predict_height():
    """Predict a wave height for one sample described by form fields.

    Form fields: the seven required numeric features in ``WAVE_FEATURES``.
    Optional settings (defaults in brackets): ``num_units`` [8],
    ``activation`` ['relu'], ``learning_rate`` [0.01], ``loss`` ['mae'],
    ``epochs`` [100]. For backward compatibility the epoch count is also
    accepted under its legacy name ``num_boost_round``.
    Files: ``train_data`` (optional CSV passed to ``predict_wave_height``).

    Returns:
        JSON ``{"code": 200, "data": str(result), "dtype": "text"}`` on
        success, otherwise ``{"msg": ..., "code": 406}``.
    """
    missing = [name for name in WAVE_FEATURES if name not in request.form]
    if missing:
        return _json_response({"msg": f"缺少参数: {', '.join(missing)}", "code": 406})

    activation = request.form.get('activation', 'relu')
    loss = request.form.get('loss', 'mae')
    try:
        num_units = int(request.form.get('num_units', 8))
        lr = float(request.form.get('learning_rate', 0.01))
        # "num_boost_round" was the original (copy-pasted) field name for the
        # epoch count; prefer the descriptive "epochs" when a client sends it.
        epochs = int(request.form.get('epochs', request.form.get('num_boost_round', 100)))
        # A single test row: shape (1, len(WAVE_FEATURES)).
        x_test = np.array([[float(request.form[name]) for name in WAVE_FEATURES]])
    except ValueError as e:
        logger.error(f"Error: {e}")
        return _json_response({"msg": str(e), "code": 406})
    logger.info(f"test data: {x_test}")

    train_file = request.files.get('train_data', None)
    train_data = None
    if train_file:
        try:
            train_data = pd.read_csv(train_file)
        except Exception as e:  # request boundary: report any parse failure
            # Report the failure instead of silently predicting without the
            # client's training data.
            logger.error(f"Error: {e}")
            return _json_response({"msg": str(e), "code": 406})

    try:
        rst = predict_wave_height(train_data, num_units, activation, lr, loss, epochs, x_test)
    except Exception as e:  # request boundary: report model failures
        logger.error(f"Error: {e}")
        return _json_response({"msg": "上传数据不符合海浪高度预测的规定文件示例,请检查", "code": 406})

    return _json_response({"code": 200, "data": str(rst), "dtype": TEXT})


@app.route("/prophet/", methods=["POST"])
def run_ts_predict():
    """Run a Prophet forecast on an uploaded CSV and return it as a CSV file.

    Files: ``data`` (CSV passed to ``run_prophet``).
    Form fields: ``period`` (converted with ``int``) and ``freq`` (passed
    through unchanged).

    Returns:
        The forecast streamed as ``result.csv`` on success, otherwise a JSON
        body ``{"msg": ..., "code": 406}``. A missing file or period surfaces
        as the exception raised by ``read_csv`` / ``int``.
    """
    file_name = "result.csv"
    data_file = request.files.get("data")
    freq = request.form.get('freq')
    period = request.form.get('period')
    try:
        data = pd.read_csv(data_file)
        logger.info(data.shape)
        rest = run_prophet(data, period=int(period), freq=freq)
        logger.info(rest.columns)
        # Stringify ds / yhat before CSV serialisation (presumably to control
        # how timestamps and floats are rendered - confirm intent).
        rest['ds'] = rest['ds'].apply(str)
        rest['yhat'] = rest['yhat'].apply(str)
    except Exception as e:  # request boundary: report any failure
        logger.error(f"Error: {e}")
        return _json_response({"msg": str(e), "code": 406})

    return _csv_response(rest, file_name, "text/csv;charset=utf-8")


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8901, debug=False)