import numpy as np import sys import os # 尝试多种方式导入GDAL try: # 尝试直接导入GDAL import gdal import ogr import osr print("成功导入GDAL库") except ImportError: try: # 尝试从osgeo导入 from osgeo import gdal, ogr, osr print("从osgeo成功导入GDAL库") except ImportError: print("无法导入GDAL库,请确保已安装GDAL及其Python绑定") sys.exit(1) def process_rasters(shp_path, raster_path, output_path): """ 使用shp文件裁剪栅格,保留不重叠部分设为value=8,然后与原栅格合并 参数: shp_path: shapefile的路径 raster_path: 原始栅格文件路径 (value 1-7) output_path: 输出栅格文件路径 """ try: # 设置GDAL错误处理 gdal.UseExceptions() # 打开栅格文件 raster = gdal.Open(raster_path) if raster is None: print(f"无法打开栅格文件: {raster_path}") return None # 获取栅格信息 print(f"栅格大小: {raster.RasterXSize} x {raster.RasterYSize}, {raster.RasterCount} 波段") geo_transform = raster.GetGeoTransform() projection = raster.GetProjection() print(f"栅格投影: {projection}") # 读取栅格数据 band = raster.GetRasterBand(1) raster_data = band.ReadAsArray() nodata = band.GetNoDataValue() if nodata is None: nodata = 0 # 打开shapefile driver = ogr.GetDriverByName("ESRI Shapefile") try: vector = driver.Open(shp_path, 0) # 0表示只读模式 except Exception as e: print(f"无法打开shapefile: {e}") return None if vector is None: print(f"无法打开shapefile: {shp_path}") return None # 获取shapefile信息 layer = vector.GetLayer() feature_count = layer.GetFeatureCount() print(f"Shapefile有 {feature_count} 个要素") # 创建栅格掩码 # 使用GDAL内置的栅格化功能将shapefile转换为栅格 print("将shapefile栅格化...") memory_driver = gdal.GetDriverByName('MEM') mask_ds = memory_driver.Create('', raster.RasterXSize, raster.RasterYSize, 1, gdal.GDT_Byte) mask_ds.SetGeoTransform(geo_transform) mask_ds.SetProjection(projection) # 栅格化 mask_band = mask_ds.GetRasterBand(1) mask_band.Fill(0) # 初始化为0 # 将shapefile栅格化到掩码上 gdal.RasterizeLayer(mask_ds, [1], layer, burn_values=[1]) # 读取掩码数据 mask_data = mask_band.ReadAsArray() # 栅格有值区域的掩码 (value 1-7) valid_data_mask = (raster_data >= 1) & (raster_data <= 7) # 获取被shapefile覆盖但不与原始有效栅格区域重叠的部分 shp_mask = mask_data > 0 # shapefile 覆盖区域 non_overlapping_mask = shp_mask & (~valid_data_mask) # shapefile 覆盖但不与原始栅格重叠区域 # 合并两个栅格 merged_raster = raster_data.copy() merged_raster[non_overlapping_mask] = 8 # 创建输出栅格 driver = gdal.GetDriverByName('GTiff') out_ds = driver.Create(output_path, raster.RasterXSize, raster.RasterYSize, 1, band.DataType) out_ds.SetGeoTransform(geo_transform) out_ds.SetProjection(projection) # 写入数据 out_band = out_ds.GetRasterBand(1) out_band.WriteArray(merged_raster) out_band.SetNoDataValue(nodata) # 关闭数据集 out_ds = None mask_ds = None raster = None vector = None print(f"处理完成,结果保存至: {output_path}") # 返回处理后的统计信息 stats = { "原始栅格有效区域(value 1-7)像素数": int(np.sum(valid_data_mask)), "shapefile覆盖区域像素数": int(np.sum(shp_mask)), "不重叠区域(value=8)像素数": int(np.sum(non_overlapping_mask)), "结果栅格总有效像素数": int(np.sum(merged_raster > 0)) } return stats except Exception as e: print(f"发生错误: {e}") import traceback traceback.print_exc() return None # 使用示例 if __name__ == "__main__": # 输入文件路径 shp_path = r"E:\Data\z18\shandongshp\sd.shp" raster_path = r"E:\Data\z18\sd\sd_y10m\sddimao_10m.tif" output_path = r"E:\Data\z18\sd\sd_10m\sddimao_10mmerged.tif" # 执行处理 stats = process_rasters(shp_path, raster_path, output_path) # 打印统计信息 if stats: for key, value in stats.items(): print(f"{key}: {value}") else: print("处理失败,无统计信息可显示")