154 lines
4.8 KiB
Python
154 lines
4.8 KiB
Python
|
import numpy as np
|
|||
|
import sys
|
|||
|
import os
|
|||
|
|
|||
|
# 尝试多种方式导入GDAL
|
|||
|
try:
|
|||
|
# 尝试直接导入GDAL
|
|||
|
import gdal
|
|||
|
import ogr
|
|||
|
import osr
|
|||
|
|
|||
|
print("成功导入GDAL库")
|
|||
|
except ImportError:
|
|||
|
try:
|
|||
|
# 尝试从osgeo导入
|
|||
|
from osgeo import gdal, ogr, osr
|
|||
|
|
|||
|
print("从osgeo成功导入GDAL库")
|
|||
|
except ImportError:
|
|||
|
print("无法导入GDAL库,请确保已安装GDAL及其Python绑定")
|
|||
|
sys.exit(1)
|
|||
|
|
|||
|
|
|||
|
def process_rasters(shp_path, raster_path, output_path):
|
|||
|
"""
|
|||
|
使用shp文件裁剪栅格,保留不重叠部分设为value=8,然后与原栅格合并
|
|||
|
|
|||
|
参数:
|
|||
|
shp_path: shapefile的路径
|
|||
|
raster_path: 原始栅格文件路径 (value 1-7)
|
|||
|
output_path: 输出栅格文件路径
|
|||
|
"""
|
|||
|
try:
|
|||
|
# 设置GDAL错误处理
|
|||
|
gdal.UseExceptions()
|
|||
|
|
|||
|
# 打开栅格文件
|
|||
|
raster = gdal.Open(raster_path)
|
|||
|
if raster is None:
|
|||
|
print(f"无法打开栅格文件: {raster_path}")
|
|||
|
return None
|
|||
|
|
|||
|
# 获取栅格信息
|
|||
|
print(f"栅格大小: {raster.RasterXSize} x {raster.RasterYSize}, {raster.RasterCount} 波段")
|
|||
|
geo_transform = raster.GetGeoTransform()
|
|||
|
projection = raster.GetProjection()
|
|||
|
print(f"栅格投影: {projection}")
|
|||
|
|
|||
|
# 读取栅格数据
|
|||
|
band = raster.GetRasterBand(1)
|
|||
|
raster_data = band.ReadAsArray()
|
|||
|
nodata = band.GetNoDataValue()
|
|||
|
if nodata is None:
|
|||
|
nodata = 0
|
|||
|
|
|||
|
# 打开shapefile
|
|||
|
driver = ogr.GetDriverByName("ESRI Shapefile")
|
|||
|
try:
|
|||
|
vector = driver.Open(shp_path, 0) # 0表示只读模式
|
|||
|
except Exception as e:
|
|||
|
print(f"无法打开shapefile: {e}")
|
|||
|
return None
|
|||
|
|
|||
|
if vector is None:
|
|||
|
print(f"无法打开shapefile: {shp_path}")
|
|||
|
return None
|
|||
|
|
|||
|
# 获取shapefile信息
|
|||
|
layer = vector.GetLayer()
|
|||
|
feature_count = layer.GetFeatureCount()
|
|||
|
print(f"Shapefile有 {feature_count} 个要素")
|
|||
|
|
|||
|
# 创建栅格掩码
|
|||
|
# 使用GDAL内置的栅格化功能将shapefile转换为栅格
|
|||
|
print("将shapefile栅格化...")
|
|||
|
memory_driver = gdal.GetDriverByName('MEM')
|
|||
|
mask_ds = memory_driver.Create('', raster.RasterXSize, raster.RasterYSize, 1, gdal.GDT_Byte)
|
|||
|
mask_ds.SetGeoTransform(geo_transform)
|
|||
|
mask_ds.SetProjection(projection)
|
|||
|
|
|||
|
# 栅格化
|
|||
|
mask_band = mask_ds.GetRasterBand(1)
|
|||
|
mask_band.Fill(0) # 初始化为0
|
|||
|
|
|||
|
# 将shapefile栅格化到掩码上
|
|||
|
gdal.RasterizeLayer(mask_ds, [1], layer, burn_values=[1])
|
|||
|
|
|||
|
# 读取掩码数据
|
|||
|
mask_data = mask_band.ReadAsArray()
|
|||
|
|
|||
|
# 栅格有值区域的掩码 (value 1-7)
|
|||
|
valid_data_mask = (raster_data >= 1) & (raster_data <= 7)
|
|||
|
|
|||
|
# 获取被shapefile覆盖但不与原始有效栅格区域重叠的部分
|
|||
|
shp_mask = mask_data > 0 # shapefile 覆盖区域
|
|||
|
non_overlapping_mask = shp_mask & (~valid_data_mask) # shapefile 覆盖但不与原始栅格重叠区域
|
|||
|
|
|||
|
# 合并两个栅格
|
|||
|
merged_raster = raster_data.copy()
|
|||
|
merged_raster[non_overlapping_mask] = 8
|
|||
|
|
|||
|
# 创建输出栅格
|
|||
|
driver = gdal.GetDriverByName('GTiff')
|
|||
|
out_ds = driver.Create(output_path, raster.RasterXSize, raster.RasterYSize, 1, band.DataType)
|
|||
|
out_ds.SetGeoTransform(geo_transform)
|
|||
|
out_ds.SetProjection(projection)
|
|||
|
|
|||
|
# 写入数据
|
|||
|
out_band = out_ds.GetRasterBand(1)
|
|||
|
out_band.WriteArray(merged_raster)
|
|||
|
out_band.SetNoDataValue(nodata)
|
|||
|
|
|||
|
# 关闭数据集
|
|||
|
out_ds = None
|
|||
|
mask_ds = None
|
|||
|
raster = None
|
|||
|
vector = None
|
|||
|
|
|||
|
print(f"处理完成,结果保存至: {output_path}")
|
|||
|
|
|||
|
# 返回处理后的统计信息
|
|||
|
stats = {
|
|||
|
"原始栅格有效区域(value 1-7)像素数": int(np.sum(valid_data_mask)),
|
|||
|
"shapefile覆盖区域像素数": int(np.sum(shp_mask)),
|
|||
|
"不重叠区域(value=8)像素数": int(np.sum(non_overlapping_mask)),
|
|||
|
"结果栅格总有效像素数": int(np.sum(merged_raster > 0))
|
|||
|
}
|
|||
|
|
|||
|
return stats
|
|||
|
except Exception as e:
|
|||
|
print(f"发生错误: {e}")
|
|||
|
import traceback
|
|||
|
traceback.print_exc()
|
|||
|
return None
|
|||
|
|
|||
|
|
|||
|
# 使用示例
|
|||
|
if __name__ == "__main__":
|
|||
|
# 输入文件路径
|
|||
|
shp_path = r"E:\Data\z18\shandongshp\sd.shp"
|
|||
|
raster_path = r"E:\Data\z18\sd\sd_y10m\sddimao_10m.tif"
|
|||
|
output_path = r"E:\Data\z18\sd\sd_10m\sddimao_10mmerged.tif"
|
|||
|
|
|||
|
# 执行处理
|
|||
|
stats = process_rasters(shp_path, raster_path, output_path)
|
|||
|
|
|||
|
# 打印统计信息
|
|||
|
if stats:
|
|||
|
for key, value in stats.items():
|
|||
|
print(f"{key}: {value}")
|
|||
|
else:
|
|||
|
print("处理失败,无统计信息可显示")
|
|||
|
|