# ai_platform_regression/run.py
#
# 138 lines
# 5.1 KiB
# Python

# -*-coding:utf-8-*-
from flask import Flask, request, make_response
import pandas as pd
import numpy as np
import json
from logzero import logger
from house_price.house_price_predcition import run_boston_price
from ocean_wave.wave_height_mlp import predict_wave_height
from prophet_predict.prophet_predict import run_prophet
# Value placed in a response's "dtype" field when its "data" is plain text.
TEXT = "text"

# Application object; the @app.route handlers in this module register on it.
app = Flask(import_name=__name__)
def _read_csv_upload(upload):
    """Parse an uploaded CSV file, returning ``None`` when there is nothing to read.

    ``upload`` is the value returned by ``request.files.get(...)``. ``None``
    (field absent) and a file with no content at all (for which
    ``pd.read_csv`` raises ``EmptyDataError``) both yield ``None``. Any other
    parse error propagates to the caller. The upload stream is read once.
    """
    if upload is None:
        return None
    try:
        return pd.read_csv(upload)
    except pd.errors.EmptyDataError:
        return None


@app.route('/house_price', methods=["POST"])
def predict_price():
    """Run ``run_boston_price`` on an uploaded test CSV (and optional train CSV).

    Form fields (optional, defaults in parentheses): ``eta`` (0.05),
    ``max_depth`` (10), ``subsample`` (0.7), ``cosample_bytree`` (0.8),
    ``num_boost_round`` (1000), ``early_stopping_rounds`` (200).
    Uploaded files: ``test_data`` (required CSV) and ``train_data``
    (optional CSV; ``None`` is passed to ``run_boston_price`` when absent).

    The HTTP status is always 200. The JSON body's ``code`` is 200 on
    success (``data`` is the result's ``to_csv()`` text, ``dtype`` is
    "csv") or 406 on failure (``msg`` describes the problem).
    """
    resp_info = dict()
    if request.method == 'POST':
        train_file = request.files.get('train_data', None)
        test_file = request.files.get('test_data', None)
        logger.info(train_file)
        try:
            # Parsing happens inside the try so that malformed numbers or an
            # unreadable CSV produce a 406 response rather than an unhandled 500.
            params = {
                "eta": float(request.form.get('eta', 0.05)),
                "max_depth": int(request.form.get('max_depth', 10)),
                "subsample": float(request.form.get('subsample', 0.7)),
                # NOTE(review): XGBoost spells this parameter "colsample_bytree".
                # The misspelled key (and form field) is kept because clients and
                # run_boston_price may rely on it; confirm before renaming.
                "cosample_bytree": float(request.form.get('cosample_bytree', 0.8)),
            }
            num_boost_round = int(request.form.get('num_boost_round', 1000))
            early_stopping_rounds = int(request.form.get('early_stopping_rounds', 200))
            # A falsy upload (field sent without a file) means "no training data".
            train_data = pd.read_csv(train_file) if train_file else None
            # Read the test upload exactly once: a second pd.read_csv on the
            # same stream would start at EOF and fail.
            test_data = _read_csv_upload(test_file)
            if test_data is None or test_data.shape[0] == 0:
                resp_info["msg"] = "测试数据为空"
                resp_info["code"] = 406
            else:
                rst = run_boston_price(test_data, train_data, num_boost_round,
                                       early_stopping_rounds, **params)
                resp_info["code"] = 200
                resp_info["data"] = rst.to_csv()
                resp_info["dtype"] = "csv"
        except Exception as e:
            logger.error(f"Error: {e}")
            resp_info["msg"] = str(e)
            resp_info["code"] = 406
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp
@app.route('/ocean_wave_height', methods=["POST"])
def predict_height():
    """Predict wave height with ``predict_wave_height`` from one posted observation.

    Required form fields, each parsed as float and passed in this order:
    ``WVHT_1``, ``WDIR_1``, ``WSPD_1``, ``WDIR_2``, ``WSPD_2``, ``WDIR``,
    ``WSPD``. Optional model fields (defaults in parentheses): ``num_units``
    (8), ``activation`` ("relu"), ``learning_rate`` (0.01), ``loss`` ("mae"),
    ``num_boost_round`` (100, used as the epoch count). An optional
    ``train_data`` CSV upload is parsed and passed on; ``None`` is passed
    when it is absent.

    The HTTP status is always 200. The JSON body's ``code`` is 200 on
    success (``data`` is the prediction, ``dtype`` is TEXT) or 406 on
    failure (``msg`` describes the problem).
    """
    feature_names = ("WVHT_1", "WDIR_1", "WSPD_1", "WDIR_2", "WSPD_2", "WDIR", "WSPD")
    resp_info = dict()
    if request.method == 'POST':
        try:
            num_units = int(request.form.get('num_units', 8))
            activation = request.form.get('activation', 'relu')
            lr = float(request.form.get('learning_rate', 0.01))
            loss = request.form.get('loss', 'mae')
            # NOTE(review): the epoch count is read from the "num_boost_round"
            # field (same name as in /house_price); kept for client compatibility.
            epochs = int(request.form.get('num_boost_round', 100))
            features = []
            for name in feature_names:
                raw = request.form.get(name)
                if raw is None:
                    # float(None) would fail with an opaque TypeError.
                    raise ValueError(f"missing form field: {name}")
                features.append(float(raw))
            x_test = np.array([features])
            logger.info(f"test data: {x_test}")
            train_file = request.files.get('train_data', None)
            train_data = pd.read_csv(train_file) if train_file else None
        except Exception as e:
            # Stop here: an unreadable training CSV or bad parameter must not
            # fall through to a prediction whose 200 would overwrite this error.
            logger.error(f"Error: {e}")
            resp_info["msg"] = str(e)
            resp_info["code"] = 406
        else:
            try:
                rst = predict_wave_height(train_data, num_units, activation, lr, loss, epochs, x_test)
            except Exception as e:
                logger.error(f"Error: {e}")
                resp_info["msg"] = "上传数据不符合海浪高度预测的规定文件示例,请检查"
                resp_info["code"] = 406
            else:
                resp_info["code"] = 200
                resp_info["data"] = rst
                resp_info["dtype"] = TEXT
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp
@app.route("/prophet/", methods=["POST"])
def run_ts_predict():
    """Forecast a time series from an uploaded CSV via ``run_prophet``.

    Expects a ``data`` CSV upload plus form fields ``period`` (parsed as
    int) and ``freq`` (passed to ``run_prophet`` unchanged; may be None).
    The HTTP status is always 200. The JSON body's ``code`` is 200 with
    ``data`` holding the ``ds`` and ``yhat`` columns (both stringified) as
    CSV text, or 406 with ``msg`` describing the problem.
    """
    resp_info = dict()
    if request.method == "POST":
        data_file = request.files.get("data")
        freq = request.form.get('freq')
        period = request.form.get('period')
        try:
            # Explicit checks give clearer messages than read_csv(None) / int(None).
            if data_file is None:
                raise ValueError("missing uploaded file: data")
            if period is None:
                raise ValueError("missing form field: period")
            data = pd.read_csv(data_file)
            logger.info(data.shape)
            # .copy() makes the column subset an independent frame, so the
            # column assignments below don't trigger SettingWithCopyWarning.
            rest = run_prophet(data, period=int(period), freq=freq)[['ds', 'yhat']].copy()
            logger.info(rest.columns)
            rest['ds'] = rest['ds'].apply(str)
            rest['yhat'] = rest['yhat'].apply(str)
        except Exception as e:
            logger.error(f"Error: {e}")
            resp_info["msg"] = str(e)
            resp_info["code"] = 406
        else:
            resp_info["code"] = 200
            resp_info["data"] = rest.to_csv()
            resp_info["dtype"] = "csv"
    resp = make_response(json.dumps(resp_info))
    resp.status_code = 200
    return resp
if __name__ == '__main__':
    import os

    # The Werkzeug interactive debugger allows arbitrary code execution from
    # the browser, and host='0.0.0.0' exposes it to the whole network. Enable
    # debug mode only on explicit request (FLASK_DEBUG=1).
    app.run(host='0.0.0.0', port=8901, debug=os.environ.get("FLASK_DEBUG") == "1")