Compare commits
2 Commits
54e8b0411b
...
8d891b420e
Author | SHA1 | Date |
---|---|---|
|
8d891b420e | |
|
68e3076f6a |
|
@ -1,2 +1,4 @@
|
|||
*.log
|
||||
tmp/
|
||||
tmp/
|
||||
datas/
|
||||
imgs/
|
Binary file not shown.
BIN
download.zip
BIN
download.zip
Binary file not shown.
1104
fast_api_run.py
1104
fast_api_run.py
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
|
@ -1,11 +1,11 @@
|
|||
A,V,FC,C,H,N,S,O,H/C,O/C,N/C,Rt,Hr,dp,T,Char
|
||||
16.6457787398883,62.0967741935484,37.9032258064516,68.38,5.6,1.12,0.42,24.05,0.982743492249196,0.263783269961977,0.0140391927464171,0.3,5.0,2.0,650,48.44904
|
||||
8.38943552563439,42.4194460146976,57.5805539853024,82.63,4.08,1.09,0.33,18.56,0.592520876195087,0.16846181774174,0.011306858456804,20.0,5.0,0.2,490,78.69
|
||||
13.5,54.1,45.9,71.6,5.8,1.4,0.9,20.3,0.972067039106145,0.212639664804469,0.0167597765363128,30.0,10.0,0.84,600,57.0752
|
||||
17.44,39.81,60.19,78.08,3.95,0.65,2.87,14.45,0.607069672131148,0.138799948770492,0.0071355386416861,20.0,5.0,0.2,510,79.9
|
||||
24.2816545626776,49.402279677509,50.597720322491,73.99,5.66,1.14,0.49,18.72,0.917961886741452,0.189755372347615,0.0132064178556948,15.0,5.0,0.2,600,66.820533
|
||||
21.14,35.73,64.27,77.41,4.39,1.62,0.51,16.08,0.68053223097791,0.155793825087198,0.0179378817797627,15.0,5.0,0.08,600,78.5678276838
|
||||
4.77944939105516,41.9496990541703,58.0503009458297,75.09,4.79,3.56,0.32,19.22,0.765481422293248,0.191969636436276,0.0406369499457793,20.0,5.0,0.2,510,84.31
|
||||
8.24,38.05,61.95,82.3,4.73,0.92,1.32,12.05,0.689671931956258,0.109811664641555,0.0095816698489845,60.0,10.0,0.2,650,73.39
|
||||
5.35137948984904,33.22,66.78,80.23,5.17,1.08,0.24,13.28,0.773276829116291,0.124143088620217,0.0115382560851837,10.0,30.0,0.07,600,73.314965
|
||||
43.585255354201,50.5566709253513,49.4433290746487,64.68,5.18,0.11,4.89,25.28,0.961038961038961,0.293135435992579,0.0014577259475218,30.0,10.0,1.0,800,59.675632
|
||||
A,V,FC,C,H,N,S,O,H/C,O/C,N/C,Rt,Hr,dp,T,Char
|
||||
16.6457787398883,62.0967741935484,37.9032258064516,68.38,5.6,1.12,0.42,24.05,0.982743492249196,0.263783269961977,0.0140391927464171,0.3,5.0,2.0,650,48.44904
|
||||
8.38943552563439,42.4194460146976,57.5805539853024,82.63,4.08,1.09,0.33,18.56,0.592520876195087,0.16846181774174,0.011306858456804,20.0,5.0,0.2,490,78.69
|
||||
13.5,54.1,45.9,71.6,5.8,1.4,0.9,20.3,0.972067039106145,0.212639664804469,0.0167597765363128,30.0,10.0,0.84,600,57.0752
|
||||
17.44,39.81,60.19,78.08,3.95,0.65,2.87,14.45,0.607069672131148,0.138799948770492,0.0071355386416861,20.0,5.0,0.2,510,79.9
|
||||
24.2816545626776,49.402279677509,50.597720322491,73.99,5.66,1.14,0.49,18.72,0.917961886741452,0.189755372347615,0.0132064178556948,15.0,5.0,0.2,600,66.820533
|
||||
21.14,35.73,64.27,77.41,4.39,1.62,0.51,16.08,0.68053223097791,0.155793825087198,0.0179378817797627,15.0,5.0,0.08,600,78.5678276838
|
||||
4.77944939105516,41.9496990541703,58.0503009458297,75.09,4.79,3.56,0.32,19.22,0.765481422293248,0.191969636436276,0.0406369499457793,20.0,5.0,0.2,510,84.31
|
||||
8.24,38.05,61.95,82.3,4.73,0.92,1.32,12.05,0.689671931956258,0.109811664641555,0.0095816698489845,60.0,10.0,0.2,650,73.39
|
||||
5.35137948984904,33.22,66.78,80.23,5.17,1.08,0.24,13.28,0.773276829116291,0.124143088620217,0.0115382560851837,10.0,30.0,0.07,600,73.314965
|
||||
43.585255354201,50.5566709253513,49.4433290746487,64.68,5.18,0.11,4.89,25.28,0.961038961038961,0.293135435992579,0.0014577259475218,30.0,10.0,1.0,800,59.675632
|
||||
|
|
|
|
@ -1,11 +1,11 @@
|
|||
A,V,FC,C,H,N,S,O,H/C,O/C,N/C,Rt,Hr,dp,T,Gas
|
||||
5.78,37.71,62.29,76.25,4.37,0.89,0.46,12.25,0.687737704918033,0.120491803278689,0.0100046838407494,30.0,10.0,13.0,650,6.1
|
||||
25.57,45.86,54.14,65.71,4.92,1.26,2.42,25.69,0.898493380003044,0.293220210013697,0.0164358545122508,20.0,5.0,0.15,600,15.38
|
||||
5.5005500550055,31.5483119906868,68.4516880093132,83.09,4.62,1.07,0.48,10.74,0.667228306655434,0.0969430737754242,0.0110379450853635,15.0,5.0,0.25,600,7.97193
|
||||
16.6457787398883,62.0967741935484,37.9032258064516,68.38,5.6,1.12,0.42,24.05,0.982743492249196,0.263783269961977,0.0140391927464171,0.3,5.0,2.0,750,10.88348
|
||||
9.79,48.99,51.01,73.91,3.98,0.88,0.49,20.75,0.646191313759978,0.210560140711676,0.0102054622417226,20.0,10.0,3.0,550,16.68
|
||||
5.44,41.2,58.8,82.0,4.79,1.51,0.58,11.1,0.7009756097560976,0.1015243902439024,0.0157839721254355,15.0,5.0,0.2,600,11.3
|
||||
38.7730061349693,47.15,52.85,74.47,4.8,1.41,0.98,18.34,0.773465825164496,0.184705250436417,0.0162289704387193,60.0,10.0,1.0,700,8.05
|
||||
15.85645,36.5599621123663,63.4400378876337,80.0704032260806,5.75550432943082,1.22339222461332,0.44301627527664,12.5076839445986,0.8625665560614234,0.1171564345937181,0.013096248608248,15.0,5.0,0.2,600,8.4
|
||||
30.38,59.81,40.19,67.36,4.54,1.67,1.12,25.31,0.808788598574822,0.281806710213777,0.0212504241601629,15.0,5.0,0.08,1100,25.8819005672
|
||||
11.2165120400292,34.8127274862041,65.1872725137959,80.57,5.39,1.01,0.49,12.54,0.802780191138141,0.116730793099168,0.0107448713629674,30.0,30.0,0.14,450,18.87111
|
||||
A,V,FC,C,H,N,S,O,H/C,O/C,N/C,Rt,Hr,dp,T,Gas
|
||||
5.78,37.71,62.29,76.25,4.37,0.89,0.46,12.25,0.687737704918033,0.120491803278689,0.0100046838407494,30.0,10.0,13.0,650,6.1
|
||||
25.57,45.86,54.14,65.71,4.92,1.26,2.42,25.69,0.898493380003044,0.293220210013697,0.0164358545122508,20.0,5.0,0.15,600,15.38
|
||||
5.5005500550055,31.5483119906868,68.4516880093132,83.09,4.62,1.07,0.48,10.74,0.667228306655434,0.0969430737754242,0.0110379450853635,15.0,5.0,0.25,600,7.97193
|
||||
16.6457787398883,62.0967741935484,37.9032258064516,68.38,5.6,1.12,0.42,24.05,0.982743492249196,0.263783269961977,0.0140391927464171,0.3,5.0,2.0,750,10.88348
|
||||
9.79,48.99,51.01,73.91,3.98,0.88,0.49,20.75,0.646191313759978,0.210560140711676,0.0102054622417226,20.0,10.0,3.0,550,16.68
|
||||
5.44,41.2,58.8,82.0,4.79,1.51,0.58,11.1,0.7009756097560976,0.1015243902439024,0.0157839721254355,15.0,5.0,0.2,600,11.3
|
||||
38.7730061349693,47.15,52.85,74.47,4.8,1.41,0.98,18.34,0.773465825164496,0.184705250436417,0.0162289704387193,60.0,10.0,1.0,700,8.05
|
||||
15.85645,36.5599621123663,63.4400378876337,80.0704032260806,5.75550432943082,1.22339222461332,0.44301627527664,12.5076839445986,0.8625665560614234,0.1171564345937181,0.013096248608248,15.0,5.0,0.2,600,8.4
|
||||
30.38,59.81,40.19,67.36,4.54,1.67,1.12,25.31,0.808788598574822,0.281806710213777,0.0212504241601629,15.0,5.0,0.08,1100,25.8819005672
|
||||
11.2165120400292,34.8127274862041,65.1872725137959,80.57,5.39,1.01,0.49,12.54,0.802780191138141,0.116730793099168,0.0107448713629674,30.0,30.0,0.14,450,18.87111
|
||||
|
|
|
|
@ -0,0 +1,11 @@
|
|||
A,V,FC,C,H,N,S,O,H/C,O/C,N/C,Rt,Hr,dp,T,Tar
|
||||
13.5,54.1,45.9,71.6,5.8,1.4,0.9,20.3,0.972067039106145,0.212639664804469,0.0167597765363128,30.0,10.0,0.84,450,6.2108384
|
||||
14.270205066345,50.41,49.59,58.69,3.14,1.24,0.21,18.12,0.642017379451355,0.231555631283012,0.0181096804030864,120.0,5.0,12.0,650,9.38428
|
||||
11.29,47.12,52.88,65.21,9.71,2.01,0.7,22.37,1.78684250881767,0.257284158871339,0.026420137139352,20.0,15.0,6.0,480,10.14648
|
||||
39.4,53.71,46.29,77.46,6.64,1.57,0.57,20.61,1.0286599535244,0.199554608830364,0.0173730220205821,15.0,30.5,5.0,600,9.6968
|
||||
11.29,47.12,52.88,65.21,9.71,2.01,0.7,22.37,1.78684250881767,0.257284158871339,0.026420137139352,10.0,15.0,6.0,480,10.07076
|
||||
25.56,45.86,54.14,65.71,4.92,1.26,2.42,25.69,0.898493380003044,0.293220210013697,0.0164358545122508,15.0,5.0,0.15,600,8.87
|
||||
11.2165120400292,34.8127274862041,65.1872725137959,80.57,5.39,1.01,0.49,12.54,0.802780191138141,0.116730793099168,0.0107448713629674,30.0,30.0,0.14,650,9.53456
|
||||
17.44,39.81,60.19,78.08,3.95,0.65,2.87,14.45,0.607069672131148,0.138799948770492,0.0071355386416861,20.0,40.0,6.0,500,6.1334
|
||||
15.84092126406,50.41,49.59,58.69,3.14,1.24,0.21,18.12,0.642017379451355,0.231555631283012,0.0181096804030864,120.0,5.0,6.0,650,6.504628
|
||||
6.17,46.5,53.5,76.14,3.06,1.06,0.24,19.5,0.482269503546099,0.192080378250591,0.0119329055499268,30.0,10.0,0.07,650,8.585445
|
|
|
@ -1,11 +1,11 @@
|
|||
A,V,FC,C,H,N,S,O,H/C,O/C,N/C,Rt,Hr,dp,T,Water
|
||||
4.54,47.35,52.65,70.41,6.95,1.15,0.43,21.06,1.18449083936941,0.224328930549638,0.0139996347921359,30.0,20.0,2.0,600,10.046517999999995
|
||||
13.5,54.1,45.9,71.6,5.8,1.4,0.9,20.3,0.972067039106145,0.212639664804469,0.0167597765363128,30.0,10.0,0.84,550,8.1107936
|
||||
11.3260262196432,54.9321376635967,45.0678623364033,63.75,4.39,1.25,0.55,30.11,0.826352941176471,0.354235294117647,0.0168067226890756,20.0,5.0,3.0,376,5.59
|
||||
10.5,42.36,57.64,82.1381852776787,5.01843263170413,0.820549411344988,0.249732429539779,30.9311452015698,0.733169248588389,0.282430867236138,0.0085627417319903,20.0,5.0,0.2,510,6.15
|
||||
10.5,42.36,57.64,82.1381852776787,5.01843263170413,0.820549411344988,0.249732429539779,30.9311452015698,0.733169248588389,0.282430867236138,0.0085627417319903,60.0,10.0,0.15,650,5.44968
|
||||
8.8735776177054,38.1,61.9,78.54,5.28,1.2,0.39,14.59,0.80672268907563,0.139323911382735,0.0130961475499291,20.0,5.0,0.07,600,4.368024
|
||||
24.2816545626776,49.402279677509,50.597720322491,73.99,5.66,1.14,0.49,18.72,0.917961886741452,0.189755372347615,0.0132064178556948,15.0,5.0,0.2,600,7.078245
|
||||
16.6326530612245,36.82,63.18,83.04,5.39,1.48,0.64,9.45,0.778901734104046,0.0853504335260115,0.0152766308835673,60.0,10.0,1.0,500,5.58
|
||||
4.4152621238755,29.6624837732583,70.3375162267417,81.78,4.79,1.1,0.38,11.95,0.702861335289802,0.10959280997799,0.0115291898123886,30.0,20.0,0.85,700,12.071072
|
||||
11.9159836065574,51.1573804815633,48.8426195184367,83.22,3.89,2.72,0.45,20.21,0.560922855082913,0.182137707281903,0.0280152435884231,30.0,5.0,0.2,510,1.1399999999999952
|
||||
A,V,FC,C,H,N,S,O,H/C,O/C,N/C,Rt,Hr,dp,T,Water
|
||||
4.54,47.35,52.65,70.41,6.95,1.15,0.43,21.06,1.18449083936941,0.224328930549638,0.0139996347921359,30.0,20.0,2.0,600,10.046517999999995
|
||||
13.5,54.1,45.9,71.6,5.8,1.4,0.9,20.3,0.972067039106145,0.212639664804469,0.0167597765363128,30.0,10.0,0.84,550,8.1107936
|
||||
11.3260262196432,54.9321376635967,45.0678623364033,63.75,4.39,1.25,0.55,30.11,0.826352941176471,0.354235294117647,0.0168067226890756,20.0,5.0,3.0,376,5.59
|
||||
10.5,42.36,57.64,82.1381852776787,5.01843263170413,0.820549411344988,0.249732429539779,30.9311452015698,0.733169248588389,0.282430867236138,0.0085627417319903,20.0,5.0,0.2,510,6.15
|
||||
10.5,42.36,57.64,82.1381852776787,5.01843263170413,0.820549411344988,0.249732429539779,30.9311452015698,0.733169248588389,0.282430867236138,0.0085627417319903,60.0,10.0,0.15,650,5.44968
|
||||
8.8735776177054,38.1,61.9,78.54,5.28,1.2,0.39,14.59,0.80672268907563,0.139323911382735,0.0130961475499291,20.0,5.0,0.07,600,4.368024
|
||||
24.2816545626776,49.402279677509,50.597720322491,73.99,5.66,1.14,0.49,18.72,0.917961886741452,0.189755372347615,0.0132064178556948,15.0,5.0,0.2,600,7.078245
|
||||
16.6326530612245,36.82,63.18,83.04,5.39,1.48,0.64,9.45,0.778901734104046,0.0853504335260115,0.0152766308835673,60.0,10.0,1.0,500,5.58
|
||||
4.4152621238755,29.6624837732583,70.3375162267417,81.78,4.79,1.1,0.38,11.95,0.702861335289802,0.10959280997799,0.0115291898123886,30.0,20.0,0.85,700,12.071072
|
||||
11.9159836065574,51.1573804815633,48.8426195184367,83.22,3.89,2.72,0.45,20.21,0.560922855082913,0.182137707281903,0.0280152435884231,30.0,5.0,0.2,510,1.1399999999999952
|
||||
|
|
|
|
@ -1,6 +1,6 @@
|
|||
import pandas as pd
|
||||
# 读取Excel文件
|
||||
file_path = "D:\\project\\ai_station\\meirejie\\data\\char_data.csv" # 替换为你的Excel文件路径
|
||||
file_path = "/home/xiazj/ai-station-code/meirejie/data/tar_data.csv" # 替换为你的Excel文件路径
|
||||
df = pd.read_csv(file_path)
|
||||
# 随机抽取10条数据
|
||||
test_set = df.sample(n=10, random_state=1) # random_state保证每次抽样结果相同
|
||||
|
@ -8,45 +8,45 @@ test_set = df.sample(n=10, random_state=1) # random_state保证每次抽样结
|
|||
columns = ['A', 'V', 'FC', 'C', 'H', 'N', 'S', 'O', 'H/C', 'O/C', 'N/C', 'Rt','Hr', 'dp', 'T','Tar']
|
||||
test_set = test_set[columns]
|
||||
# 保存测试集到新的Excel文件
|
||||
test_set.to_csv('D:\\project\\ai_station\\meirejie\\data\\char_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引
|
||||
test_set.to_csv('/home/xiazj/ai-station-code/meirejie/data/tar_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引
|
||||
print("测试集已保存为 char_data_test.csv")
|
||||
|
||||
|
||||
|
||||
file_path = "D:\\project\\ai_station\\meirejie\\data\\gas_data.csv" # 替换为你的Excel文件路径
|
||||
file_path = "/home/xiazj/ai-station-code/meirejie/data/gas_data.csv" # 替换为你的Excel文件路径
|
||||
df = pd.read_csv(file_path)
|
||||
# 随机抽取10条数据
|
||||
test_set = df.sample(n=10, random_state=1) # random_state保证每次抽样结果相同
|
||||
columns = ['A', 'V', 'FC', 'C', 'H', 'N', 'S', 'O', 'H/C', 'O/C', 'N/C', 'Rt','Hr', 'dp', 'T','Gas']
|
||||
test_set = test_set[columns]
|
||||
# 保存测试集到新的Excel文件
|
||||
test_set.to_csv('D:\\project\\ai_station\\meirejie\\data\\gas_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引
|
||||
test_set.to_csv('/home/xiazj/ai-station-code/meirejie/data/gas_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引
|
||||
|
||||
print("测试集已保存为 gas_data_test.csv")
|
||||
|
||||
|
||||
|
||||
|
||||
file_path = "D:\\project\\ai_station\\meirejie\\data\\water_data.csv" # 替换为你的Excel文件路径
|
||||
file_path = "/home/xiazj/ai-station-code/meirejie/data/water_data.csv" # 替换为你的Excel文件路径
|
||||
df = pd.read_csv(file_path)
|
||||
# 随机抽取10条数据
|
||||
test_set = df.sample(n=10, random_state=1) # random_state保证每次抽样结果相同
|
||||
columns = ['A', 'V', 'FC', 'C', 'H', 'N', 'S', 'O', 'H/C', 'O/C', 'N/C', 'Rt','Hr', 'dp', 'T','Water']
|
||||
test_set = test_set[columns]
|
||||
# 保存测试集到新的Excel文件
|
||||
test_set.to_csv('D:\\project\\ai_station\\meirejie\\data\\water_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引
|
||||
test_set.to_csv('/home/xiazj/ai-station-code/meirejie/data/water_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引
|
||||
|
||||
print("测试集已保存为 water_data_test.csv")
|
||||
|
||||
|
||||
|
||||
file_path = "D:\\project\\ai_station\\meirejie\\data\\char_data.csv" # 替换为你的Excel文件路径
|
||||
file_path = "/home/xiazj/ai-station-code/meirejie/data/char_data.csv" # 替换为你的Excel文件路径
|
||||
df = pd.read_csv(file_path)
|
||||
# 随机抽取10条数据
|
||||
test_set = df.sample(n=10, random_state=1) # random_state保证每次抽样结果相同
|
||||
columns = ['A', 'V', 'FC', 'C', 'H', 'N', 'S', 'O', 'H/C', 'O/C', 'N/C', 'Rt','Hr', 'dp', 'T','Char']
|
||||
test_set = test_set[columns]
|
||||
# 保存测试集到新的Excel文件
|
||||
test_set.to_csv('D:\\project\\ai_station\\meirejie\\data\\char_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引
|
||||
test_set.to_csv('/home/xiazj/ai-station-code/meirejie/data/char_data_test.csv', index=False) # 保存为test_set.xlsx,不保存索引
|
||||
|
||||
print("测试集已保存为 char_data_test.csv")
|
830
run.py
830
run.py
|
@ -1,18 +1,22 @@
|
|||
import sys
|
||||
|
||||
import io
|
||||
from fastapi import FastAPI,File, UploadFile,Form,Query
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.responses import FileResponse,JSONResponse
|
||||
import sys
|
||||
import os
|
||||
import shutil
|
||||
from pydantic import BaseModel, validator
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
import asyncio
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import pickle
|
||||
import cv2
|
||||
import copy
|
||||
import base64
|
||||
# 获取当前脚本所在目录
|
||||
print("Current working directory:", os.getcwd())
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
@ -25,7 +29,7 @@ from wudingpv.taihuyuan_roof.manet.model.resunet import resUnetpamcarb as roof_r
|
|||
from wudingpv.predictandeval_util import segmentation
|
||||
from guangfufadian import model_base as guangfufadian_model_base
|
||||
from fenglifadian import model_base as fenglifadian_model_base
|
||||
from work_util import prepare_data,model_deal,params,data_util,post_model
|
||||
from work_util import prepare_data,model_deal,params,data_util,post_model,sam_deal
|
||||
from work_util.logger import logger
|
||||
import joblib
|
||||
import mysql.connector
|
||||
|
@ -33,7 +37,10 @@ import uuid
|
|||
import json
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
from segment_anything_model import sam_annotator
|
||||
from segment_anything_model.sam_config import sam_config, sam_api_config
|
||||
from segment_anything_model.segment_anything import sam_model_registry, SamPredictor
|
||||
import traceback
|
||||
version = f"{sys.version_info.major}.{sys.version_info.minor}"
|
||||
|
||||
app = FastAPI()
|
||||
|
@ -77,6 +84,17 @@ windfd_args = fenglifadian_model_base.fenglifadian_Args()
|
|||
windfd_model_path = os.path.join(windfd_args.checkpoints,'Crossformer_Wind_farm_il192_ol12_sl6_win2_fa10_dm256_nh4_el3_itr0/checkpoint.pth') # 修改为实际模型路径
|
||||
windfd_model = fenglifadian_model_base.ModelInference(windfd_model_path, windfd_args)
|
||||
|
||||
|
||||
# 模型加载
|
||||
checkpoint_path = os.path.join(current_dir,'segment_anything_model/weights/vit_b.pth')
|
||||
sam = sam_model_registry["vit_b"](checkpoint=checkpoint_path)
|
||||
device = "cuda" if cv2.cuda.getCudaEnabledDeviceCount() > 0 else "cpu"
|
||||
_ = sam.to(device=device)
|
||||
sam_predictor = SamPredictor(sam)
|
||||
print(f"SAM模型已加载,使用设备: {device}")
|
||||
|
||||
|
||||
|
||||
# 将 /root/app 目录挂载为静态文件
|
||||
app.mount("/files", StaticFiles(directory="/root/app"), name="files")
|
||||
|
||||
|
@ -86,6 +104,25 @@ async def read_root():
|
|||
return {"message": message}
|
||||
|
||||
|
||||
# 首页
|
||||
# 获取数据界面资源信息
|
||||
@app.get("/ai-station-api/index/show")
|
||||
async def get_source_index_info():
|
||||
sql = "SELECT id,application_name, describe_data, img_url FROM app_shouye"
|
||||
data = data_util.fetch_data(sql)
|
||||
if data is None:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":"获取信息列表失败",
|
||||
"data":None
|
||||
}
|
||||
else:
|
||||
return {"success":True,
|
||||
"msg":"获取信息成功",
|
||||
"data":data}
|
||||
|
||||
|
||||
|
||||
|
||||
# 获取数据界面资源信息
|
||||
@app.get("/ai-station-api/data_source/show")
|
||||
|
@ -188,7 +225,7 @@ type = {tar,char,gas,water}
|
|||
"""
|
||||
@app.get("/ai-station-api/mrj_feature/show")
|
||||
async def get_mrj_feature_info(type:str = None):
|
||||
sql = "SELECT type, chinese_name, col_name, data_type, unit, data_scale FROM meirejie_features where use_type = %s;"
|
||||
sql = "SELECT type, chinese_name, col_name, data_type, unit, data_scale,best_data FROM meijitancailiao_features where use_type = %s;"
|
||||
data = data_util.fetch_data_with_param(sql,(type,))
|
||||
if data is None:
|
||||
return {
|
||||
|
@ -235,7 +272,7 @@ async def mjt_models_predict_ssa(content: post_model.Zongbiaomianji):
|
|||
new_order = ["A", "VM", "K/C", "MM", "AT", "At", "Rt"]
|
||||
meijiegou_test_content = meijiegou_test_content.reindex(columns=new_order)
|
||||
ssa_result = model_deal.pred_single_ssa(meijiegou_test_content)
|
||||
logger.info("Root endpoint was accessed")
|
||||
# logger.info("Root endpoint was accessed")
|
||||
if ssa_result is None:
|
||||
return {
|
||||
"success":False,
|
||||
|
@ -273,9 +310,14 @@ async def upload_file(file: UploadFile = File(...),type: str = Form(...), ):
|
|||
@app.get("/ai-station-api/mjt_multi_ssa/predict")
|
||||
async def mjt_multi_ssa_pred(model:str = None, path:str = None):
|
||||
data = model_deal.get_excel_ssa(model, path)
|
||||
return {"success":True,
|
||||
if data['status'] == True:
|
||||
return {"success":True,
|
||||
"msg":"获取信息成功",
|
||||
"data":data}
|
||||
"data":data['reason']}
|
||||
else:
|
||||
return {"success":False,
|
||||
"msg":data['reason'],
|
||||
"data":None}
|
||||
|
||||
|
||||
|
||||
|
@ -308,9 +350,14 @@ async def mjt_models_predict_tpv(content: post_model.Zongbiaomianji):
|
|||
@app.get("/ai-station-api/mjt_multi_tpv/predict")
|
||||
async def mjt_multi_tpv_pred(model:str = None, path:str = None):
|
||||
data = model_deal.get_excel_tpv(model, path)
|
||||
return {"success":True,
|
||||
if data['status'] == True:
|
||||
return {"success":True,
|
||||
"msg":"获取信息成功",
|
||||
"data":data}
|
||||
"data":data['reason']}
|
||||
else:
|
||||
return {"success":False,
|
||||
"msg":data['reason'],
|
||||
"data":None}
|
||||
|
||||
#==========================煤基碳材料-煤炭材料应用细节接口==================================
|
||||
@app.post("/ai-station-api/mjt_models_meitan/predict")
|
||||
|
@ -339,9 +386,14 @@ async def mjt_models_predict_meitan(content: post_model.Meitan):
|
|||
@app.get("/ai-station-api/mjt_multi_meitan/predict")
|
||||
async def mjt_multi_meitan_pred(model:str = None, path:str = None):
|
||||
data = model_deal.get_excel_meitan(model, path)
|
||||
return {"success":True,
|
||||
if data['status'] == True:
|
||||
return {"success":True,
|
||||
"msg":"获取信息成功",
|
||||
"data":data}
|
||||
"data":data['reason']}
|
||||
else:
|
||||
return {"success":False,
|
||||
"msg":data['reason'],
|
||||
"data":None}
|
||||
|
||||
|
||||
#==========================煤基碳材料-煤沥青应用细节接口==================================
|
||||
|
@ -371,10 +423,14 @@ async def mjt_models_predict_meiliqing(content: post_model.Meitan):
|
|||
@app.get("/ai-station-api/mjt_multi_meiliqing/predict")
|
||||
async def mjt_multi_meiliqing_pred(model:str = None, path:str = None):
|
||||
data = model_deal.get_excel_meiliqing(model, path)
|
||||
return {"success":True,
|
||||
if data['status'] == True:
|
||||
return {"success":True,
|
||||
"msg":"获取信息成功",
|
||||
"data":data}
|
||||
|
||||
"data":data['reason']}
|
||||
else:
|
||||
return {"success":False,
|
||||
"msg":data['reason'],
|
||||
"data":None}
|
||||
|
||||
|
||||
|
||||
|
@ -428,6 +484,176 @@ async def mjt_multi_meiliqing_pred(model:str = None, path:str = None, type:int =
|
|||
}
|
||||
|
||||
|
||||
#===============================煤热解-tar =======================================================
|
||||
"""
|
||||
模型综合分析
|
||||
"""
|
||||
@app.post("/ai-station-api/mrj_models_tar/predict")
|
||||
async def mrj_models_predict_tar(content: post_model.Meirejie):
|
||||
# 处理接收到的字典数据
|
||||
test_content = pd.DataFrame([content.model_dump()])
|
||||
test_content= test_content.rename(columns={"H_C":"H/C"})
|
||||
test_content= test_content.rename(columns={"O_C":"O/C"})
|
||||
test_content= test_content.rename(columns={"N_C":"N/C"})
|
||||
new_order = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T"]
|
||||
test_content = test_content.reindex(columns=new_order)
|
||||
tmp_result = model_deal.pred_single_tar(test_content)
|
||||
if tmp_result is None:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":"获取信息列表失败",
|
||||
"data":None
|
||||
}
|
||||
else:
|
||||
return {"success":True,
|
||||
"msg":"获取信息成功",
|
||||
"data":tmp_result}
|
||||
|
||||
|
||||
|
||||
"""
|
||||
批量预测接口
|
||||
"""
|
||||
@app.get("/ai-station-api/mrj_multi_tar/predict")
|
||||
async def mrj_multi_tar_pred(model:str = None, path:str = None):
|
||||
data = model_deal.get_excel_tar(model, path)
|
||||
if data['status'] == True:
|
||||
return {"success":True,
|
||||
"msg":"获取信息成功",
|
||||
"data":data['reason']}
|
||||
else:
|
||||
return {"success":False,
|
||||
"msg":data['reason'],
|
||||
"data":None}
|
||||
|
||||
|
||||
|
||||
|
||||
#===============================煤热解-char =======================================================
|
||||
"""
|
||||
模型综合分析
|
||||
"""
|
||||
@app.post("/ai-station-api/mrj_models_char/predict")
|
||||
async def mrj_models_predict_char(content: post_model.Meirejie):
|
||||
# 处理接收到的字典数据
|
||||
test_content = pd.DataFrame([content.model_dump()])
|
||||
test_content= test_content.rename(columns={"H_C":"H/C"})
|
||||
test_content= test_content.rename(columns={"O_C":"O/C"})
|
||||
test_content= test_content.rename(columns={"N_C":"N/C"})
|
||||
new_order = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T"]
|
||||
test_content = test_content.reindex(columns=new_order)
|
||||
tmp_result = model_deal.pred_single_char(test_content)
|
||||
if tmp_result is None:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":"获取信息列表失败",
|
||||
"data":None
|
||||
}
|
||||
else:
|
||||
return {"success":True,
|
||||
"msg":"获取信息成功",
|
||||
"data":tmp_result}
|
||||
|
||||
"""
|
||||
批量预测接口
|
||||
"""
|
||||
@app.get("/ai-station-api/mrj_multi_char/predict")
|
||||
async def mrj_multi_char_pred(model:str = None, path:str = None):
|
||||
data = model_deal.get_excel_char(model, path)
|
||||
if data['status'] == True:
|
||||
return {"success":True,
|
||||
"msg":"获取信息成功",
|
||||
"data":data['reason']}
|
||||
else:
|
||||
return {"success":False,
|
||||
"msg":data['reason'],
|
||||
"data":None}
|
||||
|
||||
|
||||
#===============================煤热解-water =======================================================
|
||||
"""
|
||||
模型综合分析
|
||||
"""
|
||||
@app.post("/ai-station-api/mrj_models_water/predict")
|
||||
async def mrj_models_predict_water(content: post_model.Meirejie):
|
||||
# 处理接收到的字典数据
|
||||
test_content = pd.DataFrame([content.model_dump()])
|
||||
test_content= test_content.rename(columns={"H_C":"H/C"})
|
||||
test_content= test_content.rename(columns={"O_C":"O/C"})
|
||||
test_content= test_content.rename(columns={"N_C":"N/C"})
|
||||
new_order = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T"]
|
||||
test_content = test_content.reindex(columns=new_order)
|
||||
tmp_result = model_deal.pred_single_water(test_content)
|
||||
if tmp_result is None:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":"获取信息列表失败",
|
||||
"data":None
|
||||
}
|
||||
else:
|
||||
return {"success":True,
|
||||
"msg":"获取信息成功",
|
||||
"data":tmp_result}
|
||||
|
||||
"""
|
||||
批量预测接口
|
||||
"""
|
||||
@app.get("/ai-station-api/mrj_multi_water/predict")
|
||||
async def mrj_multi_water_pred(model:str = None, path:str = None):
|
||||
data = model_deal.get_excel_water(model, path)
|
||||
if data['status'] == True:
|
||||
return {"success":True,
|
||||
"msg":"获取信息成功",
|
||||
"data":data['reason']}
|
||||
else:
|
||||
return {"success":False,
|
||||
"msg":data['reason'],
|
||||
"data":None}
|
||||
|
||||
|
||||
#===============================煤热解-gas =======================================================
|
||||
"""
|
||||
模型综合分析
|
||||
"""
|
||||
@app.post("/ai-station-api/mrj_models_gas/predict")
|
||||
async def mrj_models_predict_gas(content: post_model.Meirejie):
|
||||
# 处理接收到的字典数据
|
||||
test_content = pd.DataFrame([content.model_dump()])
|
||||
test_content= test_content.rename(columns={"H_C":"H/C"})
|
||||
test_content= test_content.rename(columns={"O_C":"O/C"})
|
||||
test_content= test_content.rename(columns={"N_C":"N/C"})
|
||||
new_order = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T"]
|
||||
test_content = test_content.reindex(columns=new_order)
|
||||
tmp_result = model_deal.pred_single_gas(test_content)
|
||||
if tmp_result is None:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":"获取信息列表失败",
|
||||
"data":None
|
||||
}
|
||||
else:
|
||||
return {"success":True,
|
||||
"msg":"获取信息成功",
|
||||
"data":tmp_result}
|
||||
|
||||
"""
|
||||
批量预测接口
|
||||
"""
|
||||
@app.get("/ai-station-api/mrj_multi_gas/predict")
|
||||
async def mrj_multi_gas_pred(model:str = None, path:str = None):
|
||||
data = model_deal.get_excel_gas(model, path)
|
||||
if data['status'] == True:
|
||||
return {"success":True,
|
||||
"msg":"获取信息成功",
|
||||
"data":data['reason']}
|
||||
else:
|
||||
return {"success":False,
|
||||
"msg":data['reason'],
|
||||
"data":None}
|
||||
|
||||
|
||||
|
||||
|
||||
#================================ 地貌识别 ====================================================
|
||||
"""
|
||||
上传图片 , type = dimaoshibie
|
||||
|
@ -458,6 +684,7 @@ async def upload_image(file: UploadFile = File(...),type: str = Form(...), ):
|
|||
"msg":"获取信息成功",
|
||||
"data":{"location": file_location}}
|
||||
|
||||
|
||||
"""
|
||||
图片地貌识别
|
||||
"""
|
||||
|
@ -753,7 +980,6 @@ async def get_ch4_data(path:str = None, page: int = Query(1, ge=1), page_size: i
|
|||
csv上传
|
||||
"""
|
||||
# @app.post("/ai-station-api/document/upload") type = "ch4"
|
||||
|
||||
"""
|
||||
预测, type
|
||||
"""
|
||||
|
@ -771,7 +997,6 @@ async def get_ch4_predict(path:str=None,start_time:str=None, end_time:str = None
|
|||
"msg":"获取信息成功",
|
||||
"data":data['reason']}
|
||||
|
||||
|
||||
#========================================光伏预测==========================================================
|
||||
"""
|
||||
返回显示列表
|
||||
|
@ -982,6 +1207,8 @@ csv上传
|
|||
"""
|
||||
# @app.post("/ai-station-api/document/upload") type = "wind_electric"
|
||||
|
||||
|
||||
|
||||
"""
|
||||
预测, type
|
||||
"""
|
||||
|
@ -1000,9 +1227,576 @@ async def get_wind_electri_predict(path:str=None,start_time:str=None, end_time:s
|
|||
"data":data['reason']}
|
||||
|
||||
|
||||
#======================== SAM =============================================================
|
||||
"""
|
||||
文件上传
|
||||
1、 创建inputs目录,outputs目录
|
||||
2、 将图片上传到iuputs目录
|
||||
4、 加载当前图像
|
||||
5、 返回当前图像路径给前端
|
||||
"""
|
||||
@app.post("/ai-station-api/sam-image/upload")
|
||||
async def upload_sam_image(file: UploadFile = File(...),type: str = Form(...), ):
|
||||
if not data_util.allowed_file(file.filename):
|
||||
return {
|
||||
"success":False,
|
||||
"msg":"图片必须以 '.jpg', '.jpeg', '.png', '.tif' 结尾",
|
||||
"data":None
|
||||
}
|
||||
if file.size > param.MAX_FILE_SAM_SIZE:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":"图片大小不能大于10MB",
|
||||
"data":None
|
||||
}
|
||||
upload_dir = os.path.join(current_dir,'tmp',type, str(uuid.uuid4()))
|
||||
if not os.path.exists(upload_dir):
|
||||
os.makedirs(upload_dir )
|
||||
input_dir = os.path.join(upload_dir,'input')
|
||||
os.makedirs(input_dir)
|
||||
output_dir = os.path.join(upload_dir,'output')
|
||||
os.makedirs(output_dir)
|
||||
temp_dir = os.path.join(upload_dir,'temp')
|
||||
os.makedirs(temp_dir)
|
||||
# 将文件保存到指定目录
|
||||
file_location = os.path.join(input_dir , file.filename)
|
||||
with open(file_location, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
|
||||
config_path = os.path.join(upload_dir,'model_params.pickle')
|
||||
# # 初始化配置 , 每次上传图片时,会创建一个新的配置文件
|
||||
config = copy.deepcopy(sam_config)
|
||||
api_config = copy.deepcopy(sam_api_config)
|
||||
|
||||
config['input_dir'] = input_dir
|
||||
config['output_dir'] = output_dir
|
||||
config['image_files'] = [f for f in os.listdir(input_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff'))]
|
||||
config['current_index'] = 0
|
||||
config['filename'] = config['image_files'][config['current_index']]
|
||||
image_path = os.path.join(config['input_dir'], config['filename'])
|
||||
config['image'] = cv2.imread(image_path)
|
||||
config['image_rgb'] = cv2.cvtColor(config['image'].copy(), cv2.COLOR_BGR2RGB)
|
||||
# 配置类获取参数信息
|
||||
api_config['output_dir'] = output_dir
|
||||
# # 重置pickle中的信息
|
||||
# 重置API中的class_annotations,保留类别但清除掩码和点
|
||||
config, api_config = sam_deal.reset_class_annotations(config,api_config)
|
||||
config = sam_deal.reset_annotation(config)
|
||||
|
||||
save_data = (config,api_config)
|
||||
with open(config_path, 'wb') as file:
|
||||
pickle.dump(save_data, file)
|
||||
# 将图片拷贝到temp目录下
|
||||
pil_img = Image.fromarray(config['image_rgb'])
|
||||
tmp_path = os.path.join(temp_dir,'output_image.jpg')
|
||||
pil_img.save(tmp_path)
|
||||
encoded_string = sam_deal.load_tmp_image(tmp_path)
|
||||
return {"success":True,
|
||||
"msg":"获取信息成功",
|
||||
"image": JSONResponse(content={"image_data": encoded_string}),
|
||||
#"input_dir": input_dir,
|
||||
#'output_dir': output_dir,
|
||||
#'temp_dir': temp_dir,
|
||||
'file_path': upload_dir}
|
||||
|
||||
"""
|
||||
添加分类
|
||||
添加分类,并将分类设置为最新的添加的分类;
|
||||
|
||||
要求:针对返回current_index,将列表默认成选择current_index对应的分类
|
||||
"""
|
||||
@app.get("/ai-station-api/sam_class/create")
|
||||
async def sam_class_set(
|
||||
class_name: str = None,
|
||||
color: Optional[List[int]] = Query(None, description="list of RGB color"),
|
||||
path: str = None
|
||||
):
|
||||
loaded_data,api_config = sam_deal.load_model(path)
|
||||
result = sam_deal.add_class(loaded_data,class_name,color)
|
||||
if result['status'] == True:
|
||||
loaded_data = result['reason']
|
||||
else:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":result['reason'],
|
||||
"data":None
|
||||
}
|
||||
loaded_data['class_index'] = loaded_data['class_names'].index(class_name)
|
||||
r, g, b = [int(c) for c in color]
|
||||
bgr_color = (b, g, r)
|
||||
result, api_config = sam_deal.set_current_class(loaded_data, api_config, loaded_data['class_index'], color=bgr_color)
|
||||
# 更新配置内容
|
||||
sam_deal.save_model(loaded_data,api_config,path)
|
||||
tmp_path = os.path.join(path,'temp/output_image.jpg')
|
||||
encoded_string = sam_deal.load_tmp_image(tmp_path)
|
||||
return {"success":True,
|
||||
"msg":f"已添加类别: {class_name}, 颜色: {color}",
|
||||
"image": JSONResponse(content={"image_data": encoded_string}),
|
||||
"data":{"class_name_list": loaded_data['class_names'],
|
||||
"current_index": loaded_data['class_index'],
|
||||
"class_dict":loaded_data['class_colors'],
|
||||
}}
|
||||
|
||||
|
||||
|
||||
"""
|
||||
选择颜色,
|
||||
current_index : 下拉列表中的分类索引
|
||||
rgb_color :
|
||||
"""
|
||||
@app.get("/ai-station-api/sam_color/select")
|
||||
async def set_sam_color(
|
||||
current_index: int = None,
|
||||
rgb_color: List[int] = Query(None, description="list of RGB color"),
|
||||
path: str = None
|
||||
):
|
||||
loaded_data,api_config = sam_deal.load_model(path)
|
||||
r, g, b = [int(c) for c in rgb_color]
|
||||
bgr_color = (b, g, r)
|
||||
data, api = sam_deal.set_class_color(loaded_data, api_config, current_index, bgr_color)
|
||||
result, api_config = sam_deal.set_current_class(data, api, current_index, color=bgr_color)
|
||||
sam_deal.save_model(data,api,path)
|
||||
img = sam_deal.refresh_image(data,api,path)
|
||||
if img['status'] == True:
|
||||
encoded_string = sam_deal.load_tmp_image(img['reason'])
|
||||
return {
|
||||
"success":True,
|
||||
"msg":"",
|
||||
"image": JSONResponse(content={"image_data": encoded_string}),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":img['reason'],
|
||||
"image":None
|
||||
}
|
||||
|
||||
|
||||
"""
|
||||
选择分类
|
||||
"""
|
||||
@app.get("/ai-station-api/sam_classs/change")
|
||||
async def on_class_selected(class_index : int = None,path: str = None):
|
||||
# 加载配置内容
|
||||
loaded_data,api_config = sam_deal.load_model(path)
|
||||
result, api_config = sam_deal.set_current_class(loaded_data, api_config, class_index, color=None)
|
||||
# loaded_data['class_index'] = class_index
|
||||
sam_deal.save_model(loaded_data,api_config,path)
|
||||
if result:
|
||||
img = sam_deal.refresh_image(loaded_data,api_config,path)
|
||||
if img['status'] == True:
|
||||
encoded_string = sam_deal.load_tmp_image(img['reason'])
|
||||
return {
|
||||
"success":True,
|
||||
"msg":"",
|
||||
"image":JSONResponse(content={"image_data": encoded_string})
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":img['reason'],
|
||||
"image":None
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":"分类标签识别错误",
|
||||
"image":None
|
||||
}
|
||||
|
||||
|
||||
"""
|
||||
删除分类, 前端跳转为normal ,index=0
|
||||
"""
|
||||
@app.get("/ai-station-api/sam_class/delete")
|
||||
async def sam_remove_class(path:str=None,select_index:int=None):
|
||||
loaded_data,api_config = sam_deal.load_model(path)
|
||||
class_name = loaded_data['class_names'][select_index]
|
||||
loaded_data,api_config = sam_deal.remove_class(loaded_data,api_config,class_name)
|
||||
sam_deal.save_model(loaded_data,api_config,path)
|
||||
img = sam_deal.refresh_image(loaded_data,api_config,path)
|
||||
if img['status'] == True:
|
||||
encoded_string = sam_deal.load_tmp_image(img['reason'])
|
||||
return {
|
||||
"success":True,
|
||||
"msg":"",
|
||||
"image":JSONResponse(content={"image_data": encoded_string})
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":img['reason'],
|
||||
"image":None
|
||||
}
|
||||
|
||||
"""
|
||||
添加标注点 -左键
|
||||
"""
|
||||
@app.get("/ai-station-api/sam_point/left_add")
|
||||
async def left_mouse_down(x:int=None,y:int=None,path:str=None):
|
||||
loaded_data,api_config = sam_deal.load_model(path)
|
||||
if not api_config['current_class']:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":"请先选择一个分类,在添加标点之前",
|
||||
"image":None
|
||||
}
|
||||
is_foreground = True
|
||||
result = sam_deal.add_annotation_point(api_config,x,y,is_foreground)
|
||||
if result['status']== False:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":result['reason'],
|
||||
"image":None
|
||||
}
|
||||
api_config = result['api']
|
||||
sam_deal.save_model(loaded_data,api_config,path)
|
||||
img = sam_deal.refresh_image(loaded_data,api_config,path)
|
||||
if img['status'] == True:
|
||||
encoded_string = sam_deal.load_tmp_image(img['reason'])
|
||||
return {
|
||||
"success":True,
|
||||
"msg":"",
|
||||
"image":JSONResponse(content={"image_data": encoded_string})
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":img['reason'],
|
||||
"image":None
|
||||
}
|
||||
|
||||
"""
|
||||
添加标注点 -右键
|
||||
"""
|
||||
@app.get("/ai-station-api/sam_point/right_add")
|
||||
async def right_mouse_down(x:int=None,y:int=None,path:str=None):
|
||||
loaded_data,api_config = sam_deal.load_model(path)
|
||||
if not api_config['current_class']:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":"请先选择一个分类,在添加标点之前",
|
||||
"image":None
|
||||
}
|
||||
is_foreground = False
|
||||
result = sam_deal.add_annotation_point(api_config,x,y,is_foreground)
|
||||
if result['status']== False:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":result['reason'],
|
||||
"image":None
|
||||
}
|
||||
api_config = result['api']
|
||||
sam_deal.save_model(loaded_data,api_config,path)
|
||||
img = sam_deal.refresh_image(loaded_data,api_config,path)
|
||||
if img['status'] == True:
|
||||
encoded_string = sam_deal.load_tmp_image(img['reason'])
|
||||
return {
|
||||
"success":True,
|
||||
"msg":"",
|
||||
"image":JSONResponse(content={"image_data": encoded_string})
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":img['reason'],
|
||||
"image":None
|
||||
}
|
||||
|
||||
"""
|
||||
删除上一个点
|
||||
"""
|
||||
@app.get("/ai-station-api/sam_point/delete_last")
|
||||
async def sam_delete_last_point(path:str=None):
|
||||
loaded_data,api_config = sam_deal.load_model(path)
|
||||
result = sam_deal.delete_last_point(api_config)
|
||||
if result['status'] == True:
|
||||
api_config = result['reason']
|
||||
sam_deal.save_model(loaded_data,api_config,path)
|
||||
img = sam_deal.refresh_image(loaded_data,api_config,path)
|
||||
if img['status'] == True:
|
||||
encoded_string = sam_deal.load_tmp_image(img['reason'])
|
||||
return {
|
||||
"success":True,
|
||||
"msg":"",
|
||||
"image":JSONResponse(content={"image_data": encoded_string})
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":img['reason'],
|
||||
"data":None
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":result['reason'],
|
||||
"data":None
|
||||
}
|
||||
|
||||
|
||||
"""
|
||||
删除所有点
|
||||
"""
|
||||
@app.get("/ai-station-api/sam_point/delete_all")
|
||||
async def sam_clear_all_point(path:str=None):
|
||||
loaded_data,api_config = sam_deal.load_model(path)
|
||||
result = sam_deal.reset_current_class_points(api_config)
|
||||
if result['status'] == True:
|
||||
api_config = result['reason']
|
||||
sam_deal.save_model(loaded_data,api_config,path)
|
||||
img = sam_deal.refresh_image(loaded_data,api_config,path)
|
||||
if img['status'] == True:
|
||||
encoded_string = sam_deal.load_tmp_image(img['reason'])
|
||||
return {
|
||||
"success":True,
|
||||
"msg":"",
|
||||
"image":JSONResponse(content={"image_data": encoded_string})
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":img['reason'],
|
||||
"image":None
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":result['reason'],
|
||||
"image":None
|
||||
}
|
||||
|
||||
|
||||
"""
|
||||
模型预测
|
||||
"""
|
||||
@app.get("/ai-station-api/sam_model/predict")
|
||||
async def sam_predict_mask(path:str=None):
|
||||
loaded_data,api_config = sam_deal.load_model(path)
|
||||
class_data = api_config['class_annotations'].get(api_config['current_class'], {})
|
||||
if not class_data.get('points'):
|
||||
return {
|
||||
"success":False,
|
||||
"msg":"请在预测前添加至少一个预测样本点",
|
||||
"data":None
|
||||
}
|
||||
else:
|
||||
loaded_data = sam_deal.reset_annotation(loaded_data)
|
||||
# 将标注点添加到loaded_data 中
|
||||
for i, (x, y) in enumerate(class_data['points']):
|
||||
is_foreground = class_data['point_types'][i]
|
||||
loaded_data = sam_deal.add_point(loaded_data, x, y, is_foreground=is_foreground)
|
||||
try:
|
||||
result = sam_deal.predict_mask(loaded_data,sam_predictor)
|
||||
if result['status'] == False:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":result['reason'],
|
||||
"data":None}
|
||||
|
||||
result = result['reason']
|
||||
loaded_data = result['data']
|
||||
class_data['masks'] = [np.array(mask, dtype=np.uint8) for mask in result['masks']]
|
||||
class_data['scores'] = result['scores']
|
||||
class_data['selected_mask_index'] = result['selected_index']
|
||||
if result['selected_index'] >= 0:
|
||||
class_data['selected_mask'] = class_data['masks'][result['selected_index']]
|
||||
logger.info(f"predict: Predicted {len(result['masks'])} masks, selected index: {result['selected_index']}")
|
||||
sam_deal.save_model(loaded_data,api_config,path)
|
||||
img = sam_deal.refresh_image(loaded_data,api_config,path)
|
||||
|
||||
if img['status'] == True:
|
||||
encoded_string = sam_deal.load_tmp_image(img['reason'])
|
||||
return {
|
||||
"success":True,
|
||||
"msg":"",
|
||||
"image":JSONResponse(content={"image_data": encoded_string})
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":img['reason'],
|
||||
"image":None
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"predict: Error during prediction: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return {
|
||||
"success":False,
|
||||
"msg":f"predict: Error during prediction: {str(e)}",
|
||||
"data":None}
|
||||
|
||||
|
||||
"""
|
||||
清除所有信息
|
||||
"""
|
||||
@app.get("/ai-station-api/sam_model/clear_all")
|
||||
async def sam_reset_annotation(path:str=None):
|
||||
loaded_data,api_config = sam_deal.load_model(path)
|
||||
loaded_data,api_config = sam_deal.reset_annotation_all(loaded_data,api_config)
|
||||
img = sam_deal.refresh_image(loaded_data,api_config,path)
|
||||
if img['status'] == True:
|
||||
encoded_string = sam_deal.load_tmp_image(img['reason'])
|
||||
return {
|
||||
"success":True,
|
||||
"msg":"",
|
||||
"image":JSONResponse(content={"image_data": encoded_string})
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":img['reason'],
|
||||
"image":None
|
||||
}
|
||||
|
||||
|
||||
"""
|
||||
保存预测分类
|
||||
"""
|
||||
@app.get("/ai-station-api/sam_model/save_stage")
|
||||
async def sam_add_to_class(path:str=None,class_index:int=None):
|
||||
loaded_data,api_config = sam_deal.load_model(path)
|
||||
class_name = loaded_data['class_names'][class_index]
|
||||
result = sam_deal.add_to_class(api_config,class_name)
|
||||
if result['status'] == True:
|
||||
api_config = result['reason']
|
||||
sam_deal.save_model(loaded_data,api_config,path)
|
||||
img = sam_deal.refresh_image(loaded_data,api_config,path)
|
||||
if img['status'] == True:
|
||||
encoded_string = sam_deal.load_tmp_image(img['reason'])
|
||||
return {
|
||||
"success":True,
|
||||
"msg":"",
|
||||
"image":JSONResponse(content={"image_data": encoded_string})
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":img['reason'],
|
||||
"image":None
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success":False,
|
||||
"msg":result['reason'],
|
||||
"image":None
|
||||
}
|
||||
|
||||
|
||||
|
||||
"""
|
||||
保存结果
|
||||
"""
|
||||
@app.get("/ai-station-api/sam_model/save")
|
||||
async def sam_save_annotation(path:str=None):
|
||||
loaded_data,api_config = sam_deal.load_model(path)
|
||||
|
||||
if not api_config['output_dir']:
|
||||
logger.info("save_annotation: Output directory not set")
|
||||
return {
|
||||
"success":False,
|
||||
"msg":"save_annotation: Output directory not set",
|
||||
"data":None
|
||||
}
|
||||
has_annotations = False
|
||||
for class_name, class_data in api_config['class_annotations'].items():
|
||||
if 'final_mask' in class_data and class_data['final_mask'] is not None:
|
||||
has_annotations = True
|
||||
break
|
||||
if not has_annotations:
|
||||
logger.info("save_annotation: No final masks to save")
|
||||
return {
|
||||
"success":False,
|
||||
"msg":"save_annotation: No final masks to save",
|
||||
"data":None
|
||||
}
|
||||
image_info = sam_deal.get_image_info(loaded_data)
|
||||
if not image_info:
|
||||
logger.info("save_annotation: No image info available")
|
||||
return {
|
||||
"success":False,
|
||||
"msg":"save_annotation: No image info available",
|
||||
"data":None
|
||||
}
|
||||
image_basename = os.path.splitext(image_info['filename'])[0]
|
||||
annotation_dir = os.path.join(api_config['output_dir'], image_basename)
|
||||
os.makedirs(annotation_dir, exist_ok=True)
|
||||
saved_files = []
|
||||
orig_img = loaded_data['image']
|
||||
original_img_path = os.path.join(annotation_dir, f"{image_basename}.jpg")
|
||||
cv2.imwrite(original_img_path, orig_img)
|
||||
saved_files.append(original_img_path)
|
||||
vis_img = orig_img.copy()
|
||||
img_height, img_width = orig_img.shape[:2]
|
||||
labelme_data = {
|
||||
"version": "5.1.1",
|
||||
"flags": {},
|
||||
"shapes": [],
|
||||
"imagePath": f"{image_basename}.jpg",
|
||||
"imageData": None,
|
||||
"imageHeight": img_height,
|
||||
"imageWidth": img_width
|
||||
}
|
||||
for class_name, class_data in api_config['class_annotations'].items():
|
||||
if 'final_mask' in class_data and class_data['final_mask'] is not None:
|
||||
color = api_config['class_colors'].get(class_name, (0, 255, 0))
|
||||
vis_mask = class_data['final_mask'].copy()
|
||||
color_mask = np.zeros_like(vis_img)
|
||||
color_mask[vis_mask > 0] = color
|
||||
vis_img = cv2.addWeighted(vis_img, 1.0, color_mask, 0.5, 0)
|
||||
binary_mask = (class_data['final_mask'] > 0).astype(np.uint8)
|
||||
contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
for contour in contours:
|
||||
epsilon = 0.0001 * cv2.arcLength(contour, True)
|
||||
approx_contour = cv2.approxPolyDP(contour, epsilon, True)
|
||||
points = [[float(point[0][0]), float(point[0][1])] for point in approx_contour]
|
||||
if len(points) >= 3:
|
||||
shape_data = {
|
||||
"label": class_name,
|
||||
"points": points,
|
||||
"group_id": None,
|
||||
"shape_type": "polygon",
|
||||
"flags": {}
|
||||
}
|
||||
labelme_data["shapes"].append(shape_data)
|
||||
vis_path = os.path.join(annotation_dir, f"{image_basename}_mask.jpg")
|
||||
cv2.imwrite(vis_path, vis_img)
|
||||
saved_files.append(vis_path)
|
||||
try:
|
||||
is_success, buffer = cv2.imencode(".jpg", orig_img)
|
||||
if is_success:
|
||||
img_bytes = io.BytesIO(buffer).getvalue()
|
||||
labelme_data["imageData"] = base64.b64encode(img_bytes).decode('utf-8')
|
||||
else:
|
||||
print("save_annotation: Failed to encode image data")
|
||||
labelme_data["imageData"] = ""
|
||||
except Exception as e:
|
||||
logger.error(f"save_annotation: Could not encode image data: {str(e)}")
|
||||
labelme_data["imageData"] = ""
|
||||
json_path = os.path.join(annotation_dir, f"{image_basename}.json")
|
||||
with open(json_path, 'w') as f:
|
||||
json.dump(labelme_data, f, indent=2)
|
||||
saved_files.append(json_path)
|
||||
logger.info(f"save_annotation: Annotation saved to {annotation_dir}")
|
||||
|
||||
# 将其打包
|
||||
zip_filename = "download.zip"
|
||||
zip_filepath = Path(zip_filename)
|
||||
dir_path = os.path.dirname(annotation_dir)
|
||||
# 创建 ZIP 文件
|
||||
try:
|
||||
with zipfile.ZipFile(zip_filepath, 'w') as zip_file:
|
||||
# 遍历文件夹中的所有文件
|
||||
for foldername, subfolders, filenames in os.walk(dir_path):
|
||||
for filename in filenames:
|
||||
file_path = os.path.join(foldername, filename)
|
||||
# 将文件写入 ZIP 文件
|
||||
zip_file.write(file_path, os.path.relpath(file_path, dir_path))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error creating zip file: {str(e)}")
|
||||
|
||||
# 返回 ZIP 文件作为响应
|
||||
return FileResponse(zip_filepath, media_type='application/zip', filename=zip_filename)
|
||||
|
||||
|
||||
# @app.post("/ai-station-api/items/")
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Subproject commit 473f843be79dcab0e8155a4eda08f1149b697457
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -129,6 +129,11 @@ def allowed_file(filename: str) -> bool:
|
|||
param = params.ModelParams()
|
||||
return any(filename.endswith(ext) for ext in param.ALLOWED_EXTENSIONS)
|
||||
|
||||
|
||||
|
||||
####################SAM###################################
|
||||
|
||||
|
||||
# data_info ={
|
||||
# "result": [
|
||||
# {
|
||||
|
|
|
@ -13,10 +13,10 @@ import cv2 as cv
|
|||
import torch.nn.functional as F
|
||||
from joblib import dump, load
|
||||
from fastapi import HTTPException
|
||||
|
||||
import pickle
|
||||
param = params.ModelParams()
|
||||
|
||||
|
||||
import cv2
|
||||
import traceback
|
||||
|
||||
|
||||
################################################### 图像类函数调用###########################################################################
|
||||
|
@ -579,6 +579,7 @@ def pred_single_tar(test_content):
|
|||
|
||||
|
||||
|
||||
|
||||
def pred_single_gas(test_content):
|
||||
gas_pred = []
|
||||
current_directory = os.getcwd()
|
||||
|
@ -587,7 +588,7 @@ def pred_single_gas(test_content):
|
|||
gas_model = load(model_path)
|
||||
pred = gas_model.predict(test_content)
|
||||
gas_pred.append(pred[0])
|
||||
result = [param.meirejie_tar_mae, param.meirejie_tar_r2,gas_pred]
|
||||
result = [param.meirejie_gas_mae, param.meirejie_gas_r2,gas_pred]
|
||||
# 创建 DataFrame
|
||||
df = pd.DataFrame(result, index=param.index, columns=param.columns)
|
||||
result = df.to_json(orient='index')
|
||||
|
@ -602,7 +603,7 @@ def pred_single_water(test_content):
|
|||
water_model = load(model_path)
|
||||
pred = water_model.predict(test_content)
|
||||
water_pred.append(pred[0])
|
||||
result = [param.meirejie_tar_mae, param.meirejie_tar_r2,water_pred]
|
||||
result = [param.meirejie_water_mae, param.meirejie_water_r2,water_pred]
|
||||
# 创建 DataFrame
|
||||
df = pd.DataFrame(result, index=param.index, columns=param.columns)
|
||||
result = df.to_json(orient='index')
|
||||
|
@ -616,7 +617,7 @@ def pred_single_char(test_content):
|
|||
char_model = load(model_path)
|
||||
pred = char_model.predict(test_content)
|
||||
char_pred.append(pred[0])
|
||||
result = [param.meirejie_tar_mae, param.meirejie_tar_r2,char_pred]
|
||||
result = [param.meirejie_char_mae, param.meirejie_char_r2,char_pred]
|
||||
# 创建 DataFrame
|
||||
df = pd.DataFrame(result, index=param.index, columns=param.columns)
|
||||
result = df.to_json(orient='index')
|
||||
|
@ -624,50 +625,65 @@ def pred_single_char(test_content):
|
|||
|
||||
|
||||
|
||||
def choose_model(name,data):
|
||||
def choose_model_meirejie(name,data):
|
||||
current_directory = os.getcwd()
|
||||
model_path = os.path.join(current_directory,'meirejie',param.meirejie_model_dict[name])
|
||||
model = load(model_path)
|
||||
pred = model.predict(data)
|
||||
return pred
|
||||
|
||||
|
||||
def get_excel_tar(model_name):
|
||||
data_name = param.meirejie_test_data['tar']
|
||||
current_directory = os.getcwd()
|
||||
data_path = os.path.join(current_directory,"tmp","meirejie",data_name)
|
||||
test_data = pd.read_csv(data_path)
|
||||
pred = choose_model(model_name,test_data)
|
||||
def get_excel_tar(model_name,file_path):
|
||||
if not file_path.endswith('.csv'):
|
||||
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
|
||||
test_data = pd.read_csv(file_path)
|
||||
expected_columns = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T","Tar"]
|
||||
if list(test_data.columns) != expected_columns:
|
||||
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
|
||||
# raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
|
||||
del test_data['Tar']
|
||||
pred = choose_model_meirejie(model_name,test_data)
|
||||
test_data['tar_pred'] = pred
|
||||
return test_data.to_json(orient='records', lines=True)
|
||||
return {"status":True, "reason":test_data.to_dict(orient='records')}
|
||||
|
||||
|
||||
def get_excel_gas(model_name):
|
||||
data_name = param.meirejie_test_data['gas']
|
||||
current_directory = os.getcwd()
|
||||
data_path = os.path.join(current_directory,"tmp","meirejie",data_name)
|
||||
test_data = pd.read_csv(data_path)
|
||||
pred = choose_model(model_name,test_data)
|
||||
def get_excel_gas(model_name,file_path):
|
||||
if not file_path.endswith('.csv'):
|
||||
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
|
||||
test_data = pd.read_csv(file_path)
|
||||
expected_columns = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T","Gas"]
|
||||
if list(test_data.columns) != expected_columns:
|
||||
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
|
||||
# raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
|
||||
del test_data['Gas']
|
||||
pred = choose_model_meirejie(model_name,test_data)
|
||||
test_data['gas_pred'] = pred
|
||||
return test_data.to_json(orient='records', lines=True)
|
||||
return {"status":True, "reason":test_data.to_dict(orient='records')}
|
||||
|
||||
def get_excel_char(model_name):
|
||||
data_name = param.meirejie_test_data['char']
|
||||
current_directory = os.getcwd()
|
||||
data_path = os.path.join(current_directory,"tmp","meirejie",data_name)
|
||||
test_data = pd.read_csv(data_path)
|
||||
pred = choose_model(model_name,test_data)
|
||||
def get_excel_char(model_name,file_path):
|
||||
if not file_path.endswith('.csv'):
|
||||
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
|
||||
test_data = pd.read_csv(file_path)
|
||||
expected_columns = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T","Char"]
|
||||
if list(test_data.columns) != expected_columns:
|
||||
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
|
||||
# raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
|
||||
del test_data['Char']
|
||||
pred = choose_model_meirejie(model_name,test_data)
|
||||
test_data['char_pred'] = pred
|
||||
return test_data.to_json(orient='records', lines=True)
|
||||
return {"status":True, "reason":test_data.to_dict(orient='records')}
|
||||
|
||||
def get_excel_water(model_name):
|
||||
data_name = param.meirejie_test_data['water']
|
||||
current_directory = os.getcwd()
|
||||
data_path = os.path.join(current_directory,"tmp","meirejie",data_name)
|
||||
test_data = pd.read_csv(data_path)
|
||||
pred = choose_model(model_name,test_data)
|
||||
def get_excel_water(model_name,file_path):
|
||||
if not file_path.endswith('.csv'):
|
||||
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
|
||||
test_data = pd.read_csv(file_path)
|
||||
expected_columns = ["A", "V", "FC", "C", "H", "N" ,"S" ,"O" ,"H/C" ,"O/C" ,"N/C" ,"Rt" ,"Hr", "dp","T","Water"]
|
||||
if list(test_data.columns) != expected_columns:
|
||||
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
|
||||
# raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
|
||||
del test_data['Water']
|
||||
pred = choose_model_meirejie(model_name,test_data)
|
||||
test_data['water_pred'] = pred
|
||||
return test_data.to_json(orient='records', lines=True)
|
||||
return {"status":True, "reason":test_data.to_dict(orient='records')}
|
||||
|
||||
|
||||
|
||||
|
@ -741,52 +757,59 @@ def choose_model_meijitancailiao(name,data):
|
|||
|
||||
def get_excel_ssa(model_name,file_path):
|
||||
if not file_path.endswith('.csv'):
|
||||
raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
|
||||
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
|
||||
# raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
|
||||
test_data = pd.read_csv(file_path)
|
||||
expected_columns = ["A", "VM", "K/C", "MM", "AT", "At", "Rt", "SSA"]
|
||||
if list(test_data.columns) != expected_columns:
|
||||
raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
|
||||
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
|
||||
# raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
|
||||
del test_data['SSA']
|
||||
pred = choose_model_meijitancailiao(model_name,test_data)
|
||||
test_data['ssa_pred'] = pred
|
||||
return test_data.to_json(orient='records')
|
||||
return {"status":True, "reason":test_data.to_dict(orient='records')}
|
||||
|
||||
def get_excel_tpv(model_name,file_path):
|
||||
if not file_path.endswith('.csv'):
|
||||
raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
|
||||
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
|
||||
# raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
|
||||
test_data = pd.read_csv(file_path)
|
||||
expected_columns = ["A", "VM", "K/C", "MM", "AT", "At", "Rt", "TPV"]
|
||||
if list(test_data.columns) != expected_columns:
|
||||
raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
|
||||
# raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
|
||||
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
|
||||
del test_data['TPV']
|
||||
pred = choose_model_meijitancailiao(model_name,test_data)
|
||||
test_data['tpv_pred'] = pred
|
||||
return test_data.to_json(orient='records')
|
||||
return {"status":True, "reason":test_data.to_dict(orient='records')}
|
||||
|
||||
|
||||
def get_excel_meitan(model_name,file_path):
|
||||
if not file_path.endswith('.csv'):
|
||||
raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
|
||||
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
|
||||
# raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
|
||||
test_data = pd.read_csv(file_path)
|
||||
expected_columns = ["SSA", "TPV", "N", "O", "ID/IG", "J","C"]
|
||||
if list(test_data.columns) != expected_columns:
|
||||
raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
|
||||
# raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
|
||||
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
|
||||
del test_data['C']
|
||||
pred = choose_model_meijitancailiao(model_name,test_data)
|
||||
test_data['C_pred'] = pred
|
||||
return test_data.to_json(orient='records', lines=True)
|
||||
# return test_data.to_dict(orient='records')
|
||||
return {"status":True, "reason":test_data.to_dict(orient='records')}
|
||||
|
||||
def get_excel_meiliqing(model_name,file_path):
|
||||
if not file_path.endswith('.csv'):
|
||||
raise HTTPException(status_code=400, detail="文件类型有误,必须是 CSV 文件")
|
||||
return {"status":False, "reason":"上传文件类型有误,必须是 CSV 文件"}
|
||||
test_data = pd.read_csv(file_path)
|
||||
expected_columns = ["SSA", "TPV", "N", "O", "ID/IG", "J", "C"]
|
||||
if list(test_data.columns) != expected_columns:
|
||||
raise HTTPException(status_code=400, detail=f"文件列名不匹配,预期列名为: {expected_columns}")
|
||||
return {"status":False, "reason":f"文件列名不匹配,预期列名为: {expected_columns}"}
|
||||
del test_data['C']
|
||||
pred = choose_model_meijitancailiao(model_name,test_data)
|
||||
test_data['C_pred'] = pred
|
||||
return test_data.to_json(orient='records', lines=True)
|
||||
return {"status":True, "reason":test_data.to_dict(orient='records')}
|
||||
|
||||
|
||||
def pred_func(func_name, pred_data):
|
||||
|
@ -834,4 +857,7 @@ def get_pic_path(url):
|
|||
|
||||
# 3. 添加本地根目录
|
||||
local_path = f"/root/app{relative_path}"
|
||||
return local_path
|
||||
return local_path
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -140,9 +140,18 @@ class ModelParams():
|
|||
}
|
||||
|
||||
index = ['mae', 'r2', 'result']
|
||||
columns = ['XGBoost', 'Linear Regression', 'Ridge Regression', 'Gaussian Process Regression',
|
||||
'ElasticNet Regression', 'K-Nearest Neighbors', 'Support Vector Regression',
|
||||
'Decision Tree Regression', 'Random Forest Regression', 'AdaBoost Regression']
|
||||
columns = [
|
||||
'XGBoost(极端梯度提升)',
|
||||
'Linear Regression(线性回归)',
|
||||
'Ridge Regression(岭回归)',
|
||||
'Gaussian Process Regression(高斯过程回归)',
|
||||
'ElasticNet Regression(弹性网回归)',
|
||||
'K-Nearest Neighbors(K最近邻)',
|
||||
'Support Vector Regression(支持向量回归)',
|
||||
'Decision Tree Regression(决策树回归)',
|
||||
'Random Forest Regression(随机森林回归)',
|
||||
'AdaBoost Regression(AdaBoost回归)'
|
||||
]
|
||||
meirejie_model_list_gas = ['xgb_gas','lr_gas','ridge_gas','gp_gas','en_gas','kn_gas','svr_gas','dtr_gas','rfr_gas','adb_gas']
|
||||
meirejie_model_list_char = ['xgb_char','lr_char','ridge_char','gp_char','en_char','kn_char','svr_char','dtr_char','rfr_char','adb_char']
|
||||
meirejie_model_list_water = ['xgb_water','lr_water','ridge_water','gp_water','en_water','kn_water','svr_water','dtr_water','rfr_water','adb_water']
|
||||
|
@ -224,29 +233,65 @@ class ModelParams():
|
|||
"xgb_meiliqing":"model/meiliqing_XGB.joblib",
|
||||
|
||||
}
|
||||
columns_ssa = ['XGBoost', 'Linear Regression', 'Ridge Regression', 'Gaussian Process Regression',
|
||||
'ElasticNet Regression', 'K-Nearest Neighbors', 'Support Vector Regression',
|
||||
'Decision Tree Regression', 'Random Forest Regression', 'AdaBoost Regression']
|
||||
columns_ssa = [
|
||||
'XGBoost(极端梯度提升)',
|
||||
'Linear Regression(线性回归)',
|
||||
'Ridge Regression(岭回归)',
|
||||
'Gaussian Process Regression(高斯过程回归)',
|
||||
'ElasticNet Regression(弹性网回归)',
|
||||
'K-Nearest Neighbors(最近邻居算法)',
|
||||
'Support Vector Regression(支持向量回归)',
|
||||
'Decision Tree Regression(决策树回归)',
|
||||
'Random Forest Regression(随机森林回归)',
|
||||
'AdaBoost Regression(自适应提升回归)'
|
||||
]
|
||||
meijitancailiao_model_list_ssa = ['xgb_ssa','lr_ssa','ridge_ssa','gp_ssa','en_ssa','kn_ssa','svr_ssa','dtr_ssa','rfr_ssa','adb_ssa']
|
||||
meijitancailiao_ssa_mae = [258, 407,408 ,282 ,411 ,389, 405, 288,193, 330]
|
||||
meijitancailiao_ssa_r2 = [0.92,0.82,0.82,0.89,0.81,0.82,0.87,0.88,0.95,0.88]
|
||||
columns_tpv = ['XGBoost', 'Linear Regression', 'Ridge Regression', 'Gaussian Process Regression',
|
||||
'ElasticNet Regression', 'Gradient Boosting Regression', 'Support Vector Regression',
|
||||
'Decision Tree Regression', 'Random Forest Regression', 'AdaBoost Regression']
|
||||
columns_tpv = [
|
||||
'XGBoost(极端梯度提升)',
|
||||
'Linear Regression(线性回归)',
|
||||
'Ridge Regression(岭回归)',
|
||||
'Gaussian Process Regression(高斯过程回归)',
|
||||
'ElasticNet Regression(弹性网回归)',
|
||||
'Gradient Boosting Regression(梯度提升回归)',
|
||||
'Support Vector Regression(支持向量回归)',
|
||||
'Decision Tree Regression(决策树回归)',
|
||||
'Random Forest Regression(随机森林回归)',
|
||||
'AdaBoost Regression(自适应提升回归)'
|
||||
]
|
||||
meijitancailiao_model_list_tpv = ['xgb_tpv', 'lr_tpv', 'ridge_tpv', 'gp_tpv', 'en_tpv', 'gdbt_tpv', 'svr_tpv', 'dtr_tpv', 'rfr_tpv', 'adb_tpv']
|
||||
meijitancailiao_tpv_mae = [0.2, 0.2, 0.2, 0.2, 0.2, 0.23, 0.23, 0.21, 0.16, 0.21]
|
||||
meijitancailiao_tpv_r2 = [0.81, 0.81, 0.81, 0.8, 0.82, 0.80, 0.78, 0.73, 0.85, 0.84]
|
||||
|
||||
columns_meitan = ['XGBoost', 'Linear Regression', 'Ridge Regression', 'Gaussian Process Regression',
|
||||
'ElasticNet Regression', 'Gradient Boosting Regression', 'Support Vector Regression',
|
||||
'Decision Tree Regression', 'Random Forest Regression', 'AdaBoost Regression']
|
||||
columns_meitan = [
|
||||
'XGBoost(极端梯度提升)',
|
||||
'Linear Regression(线性回归)',
|
||||
'Ridge Regression(岭回归)',
|
||||
'Gaussian Process Regression(高斯过程回归)',
|
||||
'ElasticNet Regression(弹性网回归)',
|
||||
'Gradient Boosting Regression(梯度提升回归)',
|
||||
'Support Vector Regression(支持向量回归)',
|
||||
'Decision Tree Regression(决策树回归)',
|
||||
'Random Forest Regression(随机森林回归)',
|
||||
'AdaBoost Regression(自适应提升回归)'
|
||||
]
|
||||
meijitancailiao_model_list_meitan = ['xgb_meitan', 'lr_meitan', 'ridge_meitan', 'gp_meitan', 'en_meitan', 'gdbt_meitan', 'svr_meitan', 'dtr_meitan', 'rfr_meitan', 'adb_meitan']
|
||||
meijitancailiao_meitan_mae = [8.17, 37.61, 37.66, 13.41, 20.96, 8.03, 14.89, 19.48, 12.53, 15.6]
|
||||
meijitancailiao_meitan_r2 = [0.96, 0.19, 0.19, 0.91, 0.8, 0.96, 0.88, 0.86, 0.91, 0.91]
|
||||
|
||||
columns_meiliqing = ['XGBoost', 'Linear Regression', 'Ridge Regression', 'Gaussian Process Regression',
|
||||
'ElasticNet Regression', 'Gradient Boosting Regression', 'Support Vector Regression',
|
||||
'Decision Tree Regression', 'Random Forest Regression', 'AdaBoost Regression']
|
||||
columns_meiliqing = [
|
||||
'XGBoost(极端梯度提升)',
|
||||
'Linear Regression(线性回归)',
|
||||
'Ridge Regression(岭回归)',
|
||||
'Gaussian Process Regression(高斯过程回归)',
|
||||
'ElasticNet Regression(弹性网回归)',
|
||||
'Gradient Boosting Regression(梯度提升回归)',
|
||||
'Support Vector Regression(支持向量回归)',
|
||||
'Decision Tree Regression(决策树回归)',
|
||||
'Random Forest Regression(随机森林回归)',
|
||||
'AdaBoost Regression(自适应提升回归)'
|
||||
]
|
||||
meijitancailiao_model_list_meiliqing = ['xgb_meiliqing', 'lr_meiliqing', 'ridge_meiliqing', 'gp_meiliqing', 'en_meiliqing', 'gdbt_meiliqing', 'svr_meiliqing', 'dtr_meiliqing', 'rfr_meiliqing', 'adb_meiliqing']
|
||||
meijitancailiao_meiliqing_mae = [8.38, 35.02, 35.1, 11.02, 13.58, 7.04, 13.13, 13.13, 11.25, 9.99]
|
||||
meijitancailiao_meiliqing_r2 = [0.95, 0.33, 0.33, 0.94, 0.91, 0.97, 0.88, 0.89, 0.92, 0.94]
|
||||
|
@ -276,4 +321,12 @@ class ModelParams():
|
|||
}
|
||||
|
||||
ALLOWED_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.tif'}
|
||||
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB
|
||||
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB
|
||||
MAX_FILE_SAM_SIZE = 10 * 1024 * 1024
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
DEFAULT_MODEL_PATH = r"/home/xiazj/ai-station-code/segment_anything_model/weights/vit_b.pth"
|
||||
|
|
@ -21,6 +21,26 @@ class Meitan(BaseModel):
|
|||
J: float
|
||||
|
||||
|
||||
class Meirejie(BaseModel):
|
||||
A: float
|
||||
V: float
|
||||
FC: float
|
||||
C: float
|
||||
H: float
|
||||
N: float
|
||||
S: float
|
||||
O: float
|
||||
H_C: float
|
||||
O_C: float
|
||||
N_C: float
|
||||
Rt: float
|
||||
Hr: float
|
||||
dp: float
|
||||
T: float
|
||||
|
||||
|
||||
|
||||
|
||||
class FormData(BaseModel):
|
||||
A_min: Optional[float] = None
|
||||
A_max: Optional[float] = None
|
||||
|
|
|
@ -0,0 +1,398 @@
|
|||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
import json
|
||||
import shutil
|
||||
from segment_anything_model.segment_anything import sam_model_registry, SamPredictor
|
||||
import random
|
||||
import traceback
|
||||
import base64
|
||||
import io
|
||||
from work_util.logger import logger
|
||||
import pickle
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
def get_display_image_ori(loaded_data):
|
||||
if loaded_data['image_rgb'] is None:
|
||||
logger.info("get_display_image: No image loaded")
|
||||
return {"status": False,"reason":"无图片加载"}
|
||||
display_image = loaded_data['image_rgb'].copy()
|
||||
try:
|
||||
for point, label in zip(loaded_data['input_point'], loaded_data['input_label']):
|
||||
color = (0, 255, 0) if label == 1 else (0, 0, 255)
|
||||
cv2.circle(display_image, tuple(point), 5, color, -1)
|
||||
if loaded_data['selected_mask'] is not None:
|
||||
class_name = loaded_data['class_names'][loaded_data['class_index']]
|
||||
color = loaded_data['class_colors'].get(class_name, (0, 0, 128))
|
||||
display_image = sam_apply_mask(display_image, loaded_data['selected_mask'], color)
|
||||
logger.info(f"get_display_image: Returning image with shape {display_image.shape}")
|
||||
return {"status" : True, "reason":display_image}
|
||||
except Exception as e:
|
||||
logger.info(f"get_display_image: Error processing image: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return {"status": False,"reason":f"get_display_image: Error processing image: {str(e)}"}
|
||||
|
||||
def get_all_classes(data):
|
||||
return {
|
||||
"classes": data['class_names'],
|
||||
"colors": {name: color for name, color in data['class_colors'].items()}
|
||||
}
|
||||
|
||||
def get_classes(data):
|
||||
return get_all_classes(data)
|
||||
|
||||
|
||||
def reset_annotation(loaded_data):
|
||||
loaded_data['input_point'] = []
|
||||
loaded_data['input_label'] = []
|
||||
loaded_data['selected_mask'] = None
|
||||
loaded_data['logit_input'] = None
|
||||
loaded_data['masks'] = {}
|
||||
logger.info("已重置标注状态")
|
||||
return loaded_data
|
||||
|
||||
|
||||
def remove_class(data,api,class_name):
|
||||
if class_name in api['class_annotations']:
|
||||
del api['class_annotations'][class_name]
|
||||
|
||||
if class_name == data['class_names'][data['class_index']]:
|
||||
data['class_index'] = 0
|
||||
data['class_names'].remove(class_name)
|
||||
if class_name in data['class_colors']:
|
||||
del data['class_colors'][class_name]
|
||||
if class_name in data['masks']:
|
||||
del data['masks'][class_name]
|
||||
return data,api
|
||||
|
||||
|
||||
def reset_annotation_all(data,api):
|
||||
data = reset_annotation(data)
|
||||
for class_data in api['class_annotations'].values():
|
||||
class_data['points'] = []
|
||||
class_data['point_types'] = []
|
||||
class_data['masks'] = []
|
||||
class_data['scores'] = []
|
||||
class_data['selected_mask_index'] = -1
|
||||
if 'selected_mask' in class_data:
|
||||
del class_data['selected_mask']
|
||||
return data,api
|
||||
|
||||
|
||||
def reset_class_annotations(data,api):
|
||||
"""重置class_annotations,保留类别但清除掩码和点"""
|
||||
classes = get_classes(data).get('classes', [])
|
||||
new_annotations = {}
|
||||
for class_name in classes:
|
||||
new_annotations[class_name] = {
|
||||
'points': [],
|
||||
'point_types': [],
|
||||
'masks': [],
|
||||
'selected_mask_index': -1
|
||||
}
|
||||
api['class_annotations'] = new_annotations
|
||||
logger.info("已重置class_annotations,保留类别但清除掩码和点")
|
||||
return data,api
|
||||
|
||||
|
||||
def add_class(data,class_name, color=None):
|
||||
if class_name in data['class_names']:
|
||||
logger.info(f"类别 '{class_name}' 已存在")
|
||||
return {'status':False, 'reason':f"类别 '{class_name}' 已存在"}
|
||||
data['class_names'].append(class_name)
|
||||
if color is None:
|
||||
color = tuple(np.random.randint(100, 256, 3).tolist())
|
||||
r, g, b = [int(c) for c in color]
|
||||
bgr_color = (b, g, r)
|
||||
data['class_colors'][class_name] = tuple(bgr_color)
|
||||
logger.info(f"已添加类别: {class_name}, 颜色: {tuple(color)}")
|
||||
return {'status':True, 'reason':data}
|
||||
|
||||
|
||||
def set_current_class(data, api, class_index, color=None):
|
||||
classes = get_classes(data)
|
||||
if 'classes' in classes and class_index < len(classes['classes']):
|
||||
class_name = classes['classes'][class_index]
|
||||
api['current_class'] = class_name
|
||||
if class_name not in api['class_annotations']:
|
||||
api['class_annotations'][class_name] = {
|
||||
'points': [],
|
||||
'point_types': [],
|
||||
'masks': [],
|
||||
'selected_mask_index': -1
|
||||
}
|
||||
color = data['class_colors'][class_name]
|
||||
if color:
|
||||
api['class_colors'][class_name] = color
|
||||
elif class_name not in api['class_colors']:
|
||||
predefined_colors = [
|
||||
(255, 0, 0), (0, 255, 0), (0, 0, 255),
|
||||
(255, 255, 0), (255, 0, 255), (0, 255, 255)
|
||||
]
|
||||
color_index = len(api['class_colors']) % len(predefined_colors)
|
||||
api['class_colors'][class_name] = predefined_colors[color_index]
|
||||
return class_name,api
|
||||
return None,api
|
||||
|
||||
|
||||
def add_point(data, x, y, is_foreground=True):
|
||||
data['input_point'].append([x, y])
|
||||
data['input_label'].append(1 if is_foreground else 0)
|
||||
logger.info(f"添加{'前景' if is_foreground else '背景'}点: ({x}, {y})")
|
||||
return data
|
||||
|
||||
|
||||
|
||||
def load_model(path):
|
||||
# 加载配置内容
|
||||
config_path = os.path.join(path,'model_params.pickle')
|
||||
with open(config_path, 'rb') as file:
|
||||
loaded_data,api_config = pickle.load(file)
|
||||
return loaded_data,api_config
|
||||
|
||||
def save_model(loaded_data,api_config,path):
|
||||
config_path = os.path.join(path,'model_params.pickle')
|
||||
save_data = (loaded_data,api_config)
|
||||
with open(config_path, 'wb') as file:
|
||||
pickle.dump(save_data, file)
|
||||
|
||||
|
||||
def sam_apply_mask(image, mask, color, alpha=0.5):
|
||||
masked_image = image.copy()
|
||||
for c in range(3):
|
||||
masked_image[:, :, c] = np.where(
|
||||
mask == 1,
|
||||
image[:, :, c] * (1 - alpha) + alpha * color[c],
|
||||
image[:, :, c]
|
||||
)
|
||||
return masked_image
|
||||
|
||||
|
||||
def apply_mask_overlay(image, mask, color, alpha=0.5):
|
||||
colored_mask = np.zeros_like(image)
|
||||
colored_mask[mask > 0] = color
|
||||
return cv2.addWeighted(image, 1, colored_mask, alpha, 0)
|
||||
|
||||
def load_tmp_image(path):
|
||||
with open(path, "rb") as image_file:
|
||||
# 将图片文件读取为二进制数据并进行 Base64 编码
|
||||
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
|
||||
return encoded_string
|
||||
|
||||
def refresh_image(data,api,path):
|
||||
result = get_image_display(data,api)
|
||||
if result['status'] == False:
|
||||
return {
|
||||
"status" : False,
|
||||
"reason" : result['reason']
|
||||
}
|
||||
else:
|
||||
img = result['reason']
|
||||
display_img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
pil_img = Image.fromarray(display_img_rgb)
|
||||
tmp_path = os.path.join(path,'temp/output_image.jpg')
|
||||
pil_img.save(tmp_path)
|
||||
return {
|
||||
"status" : True,
|
||||
"reason" : tmp_path
|
||||
}
|
||||
|
||||
|
||||
def set_class_color(data,api,index,color):
|
||||
class_name = data['class_names'][index]
|
||||
api['class_colors'][class_name] = color
|
||||
data['class_colors'][class_name] = color
|
||||
return data,api
|
||||
|
||||
|
||||
|
||||
def add_to_class(api,class_name):
|
||||
if class_name is None:
|
||||
class_name = api['current_class']
|
||||
if not class_name or class_name not in api['class_annotations']:
|
||||
return {
|
||||
'status':False,
|
||||
'reason': "该分类没有标注点"
|
||||
}
|
||||
class_data = api['class_annotations'][class_name]
|
||||
if 'selected_mask' not in class_data or class_data['selected_mask'] is None:
|
||||
return {
|
||||
'status':False,
|
||||
'reason': "该分类没有进行预测"
|
||||
}
|
||||
class_data['final_mask'] = class_data['selected_mask'].copy()
|
||||
class_data['points'] = []
|
||||
class_data['point_types'] = []
|
||||
class_data['masks'] = []
|
||||
class_data['scores'] = []
|
||||
class_data['selected_mask_index'] = -1
|
||||
return {
|
||||
'status':True,
|
||||
'reason': api
|
||||
}
|
||||
|
||||
|
||||
|
||||
def get_image_display(data,api):
|
||||
|
||||
if data['image'] is None:
|
||||
logger.info("get_display_image: No image loaded")
|
||||
return {"status":False, "reason":"获取图像:没有图像加载"}
|
||||
display_image = data['image'].copy()
|
||||
try:
|
||||
for point, label in zip(data['input_point'], data['input_label']):
|
||||
color = (0, 255, 0) if label == 1 else (0, 0, 255)
|
||||
cv2.circle(display_image, tuple(point), 5, color, -1)
|
||||
if data['selected_mask'] is not None:
|
||||
class_name = data['class_names'][data['class_index']]
|
||||
color = data['class_colors'].get(class_name, (0, 0, 128))
|
||||
display_image = sam_apply_mask(display_image, data['selected_mask'], color)
|
||||
logger.info(f"get_display_image: Returning image with shape {display_image.shape}")
|
||||
img = display_image
|
||||
if not isinstance(img, np.ndarray) or img.size == 0:
|
||||
logger.info(f"get_image_display: Invalid image array, shape: {img.shape if isinstance(img, np.ndarray) else 'None'}")
|
||||
return {"status":False, "reason":f"get_image_display: Invalid image array, shape: {display_image.shape if isinstance(display_image, np.ndarray) else 'None'}"}
|
||||
# 仅应用当前图片的final_mask
|
||||
for class_name, class_data in api['class_annotations'].items():
|
||||
if 'final_mask' in class_data and class_data['final_mask'] is not None:
|
||||
color = api['class_colors'].get(class_name, (0, 255, 0))
|
||||
mask = class_data['final_mask']
|
||||
if isinstance(mask, list):
|
||||
mask = np.array(mask, dtype=np.uint8)
|
||||
logger.info(f"Applying mask for class {class_name}, shape: {mask.shape}")
|
||||
img = apply_mask_overlay(img, mask, color, alpha=0.5)
|
||||
elif 'selected_mask' in class_data and class_data['selected_mask'] is not None:
|
||||
color = api['class_colors'].get(class_name, (0, 255, 0))
|
||||
mask = class_data['selected_mask']
|
||||
if isinstance(mask, list):
|
||||
mask = np.array(mask, dtype=np.uint8)
|
||||
logger.info(f"Applying mask for class {class_name}, shape: {mask.shape}")
|
||||
img = apply_mask_overlay(img, mask, color, alpha=0.5)
|
||||
if api['current_class'] and api['current_class'] in api['class_annotations']:
|
||||
class_data = api['class_annotations'][api['current_class']]
|
||||
for i, (x, y) in enumerate(class_data['points']):
|
||||
is_fg = class_data['point_types'][i]
|
||||
color = (0, 255, 0) if is_fg else (0, 0, 255)
|
||||
print(f"Drawing point at ({x}, {y}), type: {'foreground' if is_fg else 'background'}")
|
||||
cv2.circle(img, (int(x), int(y)), 5, color, -1)
|
||||
logger.info(f"get_image_display: Returning image with shape {img.shape}")
|
||||
return {"status":True, "reason":img}
|
||||
except Exception as e:
|
||||
logger.error(f"get_display_image: Error processing image: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return {"status":False, "reason":f"get_display_image: Error processing image: {str(e)}"}
|
||||
|
||||
|
||||
def add_annotation_point(api, x, y, is_foreground=True):
|
||||
if not api['current_class'] or api['current_class'] not in api['class_annotations']:
|
||||
return {
|
||||
"status": False,
|
||||
"reason": "请选择或新建分类"
|
||||
}
|
||||
class_data = api['class_annotations'][api['current_class']]
|
||||
|
||||
class_data['points'].append((x, y))
|
||||
class_data['point_types'].append(is_foreground)
|
||||
return {
|
||||
'status' : True,
|
||||
'reason' : {
|
||||
'points': class_data['points'],
|
||||
'types': class_data['point_types']
|
||||
},
|
||||
'api':api
|
||||
}
|
||||
|
||||
|
||||
|
||||
def delete_last_point(api):
|
||||
if not api['current_class'] or api['current_class'] not in api['class_annotations']:
|
||||
return {
|
||||
"status": False,
|
||||
"reason": "当前类没有点需要删除"
|
||||
}
|
||||
class_data = api['class_annotations'][api['current_class']]
|
||||
if not class_data['points']:
|
||||
return {
|
||||
"status": False,
|
||||
"reason": "当前类没有点需要删除"
|
||||
}
|
||||
class_data['points'].pop()
|
||||
class_data['point_types'].pop()
|
||||
return {
|
||||
"status": True,
|
||||
"reason": api
|
||||
}
|
||||
|
||||
|
||||
|
||||
def reset_current_class_points(api):
|
||||
if not api['current_class'] or api['current_class'] not in api['class_annotations']:
|
||||
return {
|
||||
"status": False,
|
||||
"reason": "当前类没有点需要删除"
|
||||
}
|
||||
class_data = api['class_annotations'][api['current_class']]
|
||||
class_data['points'] = []
|
||||
class_data['point_types'] = []
|
||||
return {
|
||||
"status": True,
|
||||
"reason": api
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
def predict_mask(data,sam_predictor):
|
||||
if data['image_rgb'] is None:
|
||||
logger.info("predict_mask: No image loaded")
|
||||
return {"status":False, "reason":"预测掩码:没有图像加载"}
|
||||
if len(data['input_point']) == 0:
|
||||
logger.info("predict_mask: No points added")
|
||||
return {"status":False, "reason":"预测掩码:没有进行点标注"}
|
||||
try:
|
||||
sam_predictor.set_image(data['image_rgb'])
|
||||
except Exception as e:
|
||||
logger.error(f"predict_mask: Error setting image: {str(e)}")
|
||||
return {"status":False, "reason":f"predict_mask: Error setting image: {str(e)}"}
|
||||
input_point_np = np.array(data['input_point'])
|
||||
input_label_np = np.array(data['input_label'])
|
||||
try:
|
||||
masks_pred, scores, logits = sam_predictor.predict(
|
||||
point_coords=input_point_np,
|
||||
point_labels=input_label_np,
|
||||
mask_input=data['logit_input'][None, :, :] if data['logit_input'] is not None else None,
|
||||
multimask_output=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"predict_mask: Error during prediction: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return {"status":False, "reason":f"predict_mask: Error during prediction: {str(e)}"}
|
||||
data['masks_pred'] = masks_pred
|
||||
data['scores'] = scores
|
||||
data['logits'] = logits
|
||||
best_mask_idx = np.argmax(scores)
|
||||
data['selected_mask'] = masks_pred[best_mask_idx]
|
||||
data['logit_input'] = logits[best_mask_idx, :, :]
|
||||
logger.info(f"predict_mask: Predicted {len(masks_pred)} masks, best score: {scores[best_mask_idx]:.4f}")
|
||||
return {
|
||||
"status":True,
|
||||
"reason":{
|
||||
"masks": [mask.tolist() for mask in masks_pred],
|
||||
"scores": scores.tolist(),
|
||||
"selected_index": int(best_mask_idx),
|
||||
"data":data
|
||||
}}
|
||||
|
||||
|
||||
def get_image_info(data):
|
||||
|
||||
return {
|
||||
"filename": data['filename'],
|
||||
"index": data['current_index'],
|
||||
"total": len(data['image_files']),
|
||||
"width": data['image'].shape[1],
|
||||
"height": data['image'].shape[0]
|
||||
}
|
||||
|
Loading…
Reference in New Issue