154 lines
4.8 KiB
Python
154 lines
4.8 KiB
Python
import numpy as np
|
||
import sys
|
||
import os
|
||
|
||
# 尝试多种方式导入GDAL
|
||
try:
|
||
# 尝试直接导入GDAL
|
||
import gdal
|
||
import ogr
|
||
import osr
|
||
|
||
print("成功导入GDAL库")
|
||
except ImportError:
|
||
try:
|
||
# 尝试从osgeo导入
|
||
from osgeo import gdal, ogr, osr
|
||
|
||
print("从osgeo成功导入GDAL库")
|
||
except ImportError:
|
||
print("无法导入GDAL库,请确保已安装GDAL及其Python绑定")
|
||
sys.exit(1)
|
||
|
||
|
||
def process_rasters(shp_path, raster_path, output_path):
|
||
"""
|
||
使用shp文件裁剪栅格,保留不重叠部分设为value=8,然后与原栅格合并
|
||
|
||
参数:
|
||
shp_path: shapefile的路径
|
||
raster_path: 原始栅格文件路径 (value 1-7)
|
||
output_path: 输出栅格文件路径
|
||
"""
|
||
try:
|
||
# 设置GDAL错误处理
|
||
gdal.UseExceptions()
|
||
|
||
# 打开栅格文件
|
||
raster = gdal.Open(raster_path)
|
||
if raster is None:
|
||
print(f"无法打开栅格文件: {raster_path}")
|
||
return None
|
||
|
||
# 获取栅格信息
|
||
print(f"栅格大小: {raster.RasterXSize} x {raster.RasterYSize}, {raster.RasterCount} 波段")
|
||
geo_transform = raster.GetGeoTransform()
|
||
projection = raster.GetProjection()
|
||
print(f"栅格投影: {projection}")
|
||
|
||
# 读取栅格数据
|
||
band = raster.GetRasterBand(1)
|
||
raster_data = band.ReadAsArray()
|
||
nodata = band.GetNoDataValue()
|
||
if nodata is None:
|
||
nodata = 0
|
||
|
||
# 打开shapefile
|
||
driver = ogr.GetDriverByName("ESRI Shapefile")
|
||
try:
|
||
vector = driver.Open(shp_path, 0) # 0表示只读模式
|
||
except Exception as e:
|
||
print(f"无法打开shapefile: {e}")
|
||
return None
|
||
|
||
if vector is None:
|
||
print(f"无法打开shapefile: {shp_path}")
|
||
return None
|
||
|
||
# 获取shapefile信息
|
||
layer = vector.GetLayer()
|
||
feature_count = layer.GetFeatureCount()
|
||
print(f"Shapefile有 {feature_count} 个要素")
|
||
|
||
# 创建栅格掩码
|
||
# 使用GDAL内置的栅格化功能将shapefile转换为栅格
|
||
print("将shapefile栅格化...")
|
||
memory_driver = gdal.GetDriverByName('MEM')
|
||
mask_ds = memory_driver.Create('', raster.RasterXSize, raster.RasterYSize, 1, gdal.GDT_Byte)
|
||
mask_ds.SetGeoTransform(geo_transform)
|
||
mask_ds.SetProjection(projection)
|
||
|
||
# 栅格化
|
||
mask_band = mask_ds.GetRasterBand(1)
|
||
mask_band.Fill(0) # 初始化为0
|
||
|
||
# 将shapefile栅格化到掩码上
|
||
gdal.RasterizeLayer(mask_ds, [1], layer, burn_values=[1])
|
||
|
||
# 读取掩码数据
|
||
mask_data = mask_band.ReadAsArray()
|
||
|
||
# 栅格有值区域的掩码 (value 1-7)
|
||
valid_data_mask = (raster_data >= 1) & (raster_data <= 7)
|
||
|
||
# 获取被shapefile覆盖但不与原始有效栅格区域重叠的部分
|
||
shp_mask = mask_data > 0 # shapefile 覆盖区域
|
||
non_overlapping_mask = shp_mask & (~valid_data_mask) # shapefile 覆盖但不与原始栅格重叠区域
|
||
|
||
# 合并两个栅格
|
||
merged_raster = raster_data.copy()
|
||
merged_raster[non_overlapping_mask] = 8
|
||
|
||
# 创建输出栅格
|
||
driver = gdal.GetDriverByName('GTiff')
|
||
out_ds = driver.Create(output_path, raster.RasterXSize, raster.RasterYSize, 1, band.DataType)
|
||
out_ds.SetGeoTransform(geo_transform)
|
||
out_ds.SetProjection(projection)
|
||
|
||
# 写入数据
|
||
out_band = out_ds.GetRasterBand(1)
|
||
out_band.WriteArray(merged_raster)
|
||
out_band.SetNoDataValue(nodata)
|
||
|
||
# 关闭数据集
|
||
out_ds = None
|
||
mask_ds = None
|
||
raster = None
|
||
vector = None
|
||
|
||
print(f"处理完成,结果保存至: {output_path}")
|
||
|
||
# 返回处理后的统计信息
|
||
stats = {
|
||
"原始栅格有效区域(value 1-7)像素数": int(np.sum(valid_data_mask)),
|
||
"shapefile覆盖区域像素数": int(np.sum(shp_mask)),
|
||
"不重叠区域(value=8)像素数": int(np.sum(non_overlapping_mask)),
|
||
"结果栅格总有效像素数": int(np.sum(merged_raster > 0))
|
||
}
|
||
|
||
return stats
|
||
except Exception as e:
|
||
print(f"发生错误: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return None
|
||
|
||
|
||
# 使用示例
|
||
if __name__ == "__main__":
|
||
# 输入文件路径
|
||
shp_path = r"E:\Data\z18\shandongshp\sd.shp"
|
||
raster_path = r"E:\Data\z18\sd\sd_y10m\sddimao_10m.tif"
|
||
output_path = r"E:\Data\z18\sd\sd_10m\sddimao_10mmerged.tif"
|
||
|
||
# 执行处理
|
||
stats = process_rasters(shp_path, raster_path, output_path)
|
||
|
||
# 打印统计信息
|
||
if stats:
|
||
for key, value in stats.items():
|
||
print(f"{key}: {value}")
|
||
else:
|
||
print("处理失败,无统计信息可显示")
|
||
|