GreenTransPowerCalculate/deeplabv3sdRenewable/tools/山东省地貌识别tools/潜力评估阶段/地貌类型矢量转栅格-cropland后处理部分.py

154 lines
4.8 KiB
Python
Raw Normal View History

2025-04-27 09:58:17 +08:00
import numpy as np
import sys
import os
# 尝试多种方式导入GDAL
try:
# 尝试直接导入GDAL
import gdal
import ogr
import osr
print("成功导入GDAL库")
except ImportError:
try:
# 尝试从osgeo导入
from osgeo import gdal, ogr, osr
print("从osgeo成功导入GDAL库")
except ImportError:
print("无法导入GDAL库请确保已安装GDAL及其Python绑定")
sys.exit(1)
def process_rasters(shp_path, raster_path, output_path):
"""
使用shp文件裁剪栅格保留不重叠部分设为value=8然后与原栅格合并
参数:
shp_path: shapefile的路径
raster_path: 原始栅格文件路径 (value 1-7)
output_path: 输出栅格文件路径
"""
try:
# 设置GDAL错误处理
gdal.UseExceptions()
# 打开栅格文件
raster = gdal.Open(raster_path)
if raster is None:
print(f"无法打开栅格文件: {raster_path}")
return None
# 获取栅格信息
print(f"栅格大小: {raster.RasterXSize} x {raster.RasterYSize}, {raster.RasterCount} 波段")
geo_transform = raster.GetGeoTransform()
projection = raster.GetProjection()
print(f"栅格投影: {projection}")
# 读取栅格数据
band = raster.GetRasterBand(1)
raster_data = band.ReadAsArray()
nodata = band.GetNoDataValue()
if nodata is None:
nodata = 0
# 打开shapefile
driver = ogr.GetDriverByName("ESRI Shapefile")
try:
vector = driver.Open(shp_path, 0) # 0表示只读模式
except Exception as e:
print(f"无法打开shapefile: {e}")
return None
if vector is None:
print(f"无法打开shapefile: {shp_path}")
return None
# 获取shapefile信息
layer = vector.GetLayer()
feature_count = layer.GetFeatureCount()
print(f"Shapefile有 {feature_count} 个要素")
# 创建栅格掩码
# 使用GDAL内置的栅格化功能将shapefile转换为栅格
print("将shapefile栅格化...")
memory_driver = gdal.GetDriverByName('MEM')
mask_ds = memory_driver.Create('', raster.RasterXSize, raster.RasterYSize, 1, gdal.GDT_Byte)
mask_ds.SetGeoTransform(geo_transform)
mask_ds.SetProjection(projection)
# 栅格化
mask_band = mask_ds.GetRasterBand(1)
mask_band.Fill(0) # 初始化为0
# 将shapefile栅格化到掩码上
gdal.RasterizeLayer(mask_ds, [1], layer, burn_values=[1])
# 读取掩码数据
mask_data = mask_band.ReadAsArray()
# 栅格有值区域的掩码 (value 1-7)
valid_data_mask = (raster_data >= 1) & (raster_data <= 7)
# 获取被shapefile覆盖但不与原始有效栅格区域重叠的部分
shp_mask = mask_data > 0 # shapefile 覆盖区域
non_overlapping_mask = shp_mask & (~valid_data_mask) # shapefile 覆盖但不与原始栅格重叠区域
# 合并两个栅格
merged_raster = raster_data.copy()
merged_raster[non_overlapping_mask] = 8
# 创建输出栅格
driver = gdal.GetDriverByName('GTiff')
out_ds = driver.Create(output_path, raster.RasterXSize, raster.RasterYSize, 1, band.DataType)
out_ds.SetGeoTransform(geo_transform)
out_ds.SetProjection(projection)
# 写入数据
out_band = out_ds.GetRasterBand(1)
out_band.WriteArray(merged_raster)
out_band.SetNoDataValue(nodata)
# 关闭数据集
out_ds = None
mask_ds = None
raster = None
vector = None
print(f"处理完成,结果保存至: {output_path}")
# 返回处理后的统计信息
stats = {
"原始栅格有效区域(value 1-7)像素数": int(np.sum(valid_data_mask)),
"shapefile覆盖区域像素数": int(np.sum(shp_mask)),
"不重叠区域(value=8)像素数": int(np.sum(non_overlapping_mask)),
"结果栅格总有效像素数": int(np.sum(merged_raster > 0))
}
return stats
except Exception as e:
print(f"发生错误: {e}")
import traceback
traceback.print_exc()
return None
# 使用示例
if __name__ == "__main__":
# 输入文件路径
shp_path = r"E:\Data\z18\shandongshp\sd.shp"
raster_path = r"E:\Data\z18\sd\sd_y10m\sddimao_10m.tif"
output_path = r"E:\Data\z18\sd\sd_10m\sddimao_10mmerged.tif"
# 执行处理
stats = process_rasters(shp_path, raster_path, output_path)
# 打印统计信息
if stats:
for key, value in stats.items():
print(f"{key}: {value}")
else:
print("处理失败,无统计信息可显示")