遥感时空融合算法（一）：STARFM（Python 实现）

参考 GitHub 代码链接：https://github.com/shx951104/remote-sensing-images-fusion/blob/b4b4147c7896468516bd84c544a98270cd26589b/starfm_torch.py
在原代码基础上稍作修改。
输入：
- 同一日期 t1 的一对高分辨率数据 s1 与低分辨率数据 l1；
- t2 时刻的低分辨率数据 l2。
输出：t2 时刻的高分辨率数据 s2（预测结果）。
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 17 15:15:36 2020
@author: Administrator
"""
import numpy as np
import torch
import torch.nn as nn
import time
#import skimage.measure as sm
import skimage.metrics as sm
import cv2
from osgeo import gdal,gdalconst
import matplotlib.pyplot as plt
import skimage.io as io
from skimage.transform import resize
from utils import *
###weight caculate tools######################################################
def weight_caculate(data):
    """Return the log-scaled difference weight used by STARFM.

    Computes ``log(|data| * 10000 + 1.00001)`` element-wise. The argument of
    the log is always above 1, so for finite input the result is strictly
    positive and can later be inverted without dividing by zero.
    """
    scaled = abs(data) * 10000
    return torch.log(scaled + 1.00001)
def caculate_weight(s1l1, l1l2):
    """Combine the spectral and temporal difference weights.

    Parameters
    ----------
    s1l1 : torch.Tensor
        Difference between the two images of the same date (the
        "atmospheric"/spectral difference term).
    l1l2 : torch.Tensor
        Difference between images of the two dates (the temporal term).

    Returns
    -------
    torch.Tensor
        Element-wise product of the two log-scaled weights
        (see ``weight_caculate``).
    """
    spectral_term = weight_caculate(s1l1)
    temporal_term = weight_caculate(l1l2)
    return spectral_term * temporal_term
###space distance caculate tool################################################
def indexdistance(window):
    """Build the relative spatial-distance matrix for one moving window.

    Each entry is ``1 + d / A``. Here ``d`` is the Euclidean distance, in
    pixels, from that cell to the window centre. ``A = (window[0] - 1) // 2``
    is half the window height, clamped to at least 1.

    Parameters
    ----------
    window : (int, int)
        Window size as (rows, cols), the same order used as the ``unfold``
        kernel size.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(window[0], window[1])``. The centre cell is 1 and
        values grow with distance from the centre.
    """
    n_rows, n_cols = window[0], window[1]
    # 'ij' indexing yields (rows, cols)-shaped grids, matching how callers
    # reshape the result to (window[0], window[1]). The former default 'xy'
    # indexing produced (cols, rows), which only coincides for square windows.
    row_idx, col_idx = np.meshgrid(np.arange(n_rows), np.arange(n_cols), indexing="ij")
    center_row, center_col = (n_rows - 1) // 2, (n_cols - 1) // 2
    # Clamp the normaliser so 1- and 2-pixel windows do not divide by zero
    # (which previously yielded NaN/inf weights).
    scale = max((n_rows - 1) // 2, 1)
    offset = ((row_idx - center_row) ** 2 + (col_idx - center_col) ** 2) ** 0.5
    return 1 + offset / scale
###threshold select tool######################################################
def weight_bythreshold(weight, data, threshold):
    """Mark entries whose value does not exceed a threshold.

    Sets ``weight`` to 1, in place, wherever ``data <= threshold`` and
    returns that same tensor. All other entries keep their current value.
    """
    within = data <= threshold
    weight[within] = 1
    return weight
def weight_bythreshold_allbands(weight, l1m1, m1m2, thresholdmax):
    """Keep only window positions that pass the filter in every band.

    Parameters
    ----------
    weight : torch.Tensor
        Zero tensor of shape (bands, n, m). It is modified in place: entries
        where ``l1m1 <= thresholdmax[0]`` or ``m1m2 <= thresholdmax[1]`` are
        set to 1.
    l1m1, m1m2 : torch.Tensor
        Difference tensors with the same shape as ``weight``.
    thresholdmax : pair of tensors
        Per-position thresholds that broadcast against ``l1m1`` / ``m1m2``.

    Returns
    -------
    torch.Tensor
        Tensor of shape (1, n, m). Each entry is 1 where every band was
        marked, and 0 otherwise.
    """
    spectral_max, temporal_max = thresholdmax[0], thresholdmax[1]
    weight[(l1m1 <= spectral_max) | (m1m2 <= temporal_max)] = 1
    band_count = weight.shape[0]
    mean_over_bands = weight.sum(dim=0, keepdim=True) / band_count
    # A mean of exactly 1 means every band passed; anything else is rejected.
    mean_over_bands[mean_over_bands != 1] = 0
    return mean_over_bands
###initial similar pixels tools################################################
def spectral_similar_threshold(clusters, NIR, red):
    """Initial similar-pixel thresholds for the NIR and red bands.

    Each threshold is ``2 * std(band) / clusters``, using torch's default
    ``std()``.

    Returns
    -------
    tuple
        ``(NIR_threshold, red_threshold)``.
    """
    def _band_threshold(band):
        return band.std() * 2 / clusters

    return (_band_threshold(NIR), _band_threshold(red))
def caculate_similar(l1, threshold, window):
    """Flag, for every window position, neighbours close to the window centre.

    Parameters
    ----------
    l1 : torch.Tensor
        Single-band image of shape (1, 1, h, w). ``starfm_main`` passes one
        band reshaped with ``.view(1, 1, ...)``.
    threshold : float or 0-d tensor
        Maximum allowed absolute difference from the centre value.
    window : (int, int)
        Moving-window (``unfold`` kernel) size.

    Returns
    -------
    torch.Tensor
        float32 tensor shaped like ``unfold(l1, window)``, i.e.
        (1, window[0] * window[1], n_positions). Entries are 1 where
        ``|neighbour - centre| <= threshold`` and 0 elsewhere.
    """
    cpu = torch.device("cpu")
    columns = nn.functional.unfold(l1, window)
    # Middle entry along dim 1 is the window centre when both sizes are odd.
    mid = (columns.size()[1] - 1) // 2
    gap = abs(columns - columns[:, mid:mid + 1, :])
    flags = torch.zeros(columns.shape, dtype=torch.float32).to(cpu)
    return weight_bythreshold(flags, gap, threshold)
def classifier(l1):
    """Placeholder for classification-based similar-pixel selection.

    Not used anywhere; it ignores its argument and returns ``None``.
    """
    return None
###similar pixels filter tools#################################################
def allband_arrayindex(arraylist, indexarray, rawindexshape):
    """Gather the same pixel subset from every band of several images.

    Parameters
    ----------
    arraylist : list of torch.Tensor
        Images indexed as ``array[0, band]``. The band count is taken from
        ``arraylist[0].shape[1]``.
    indexarray : index
        Advanced index applied to each 2-D band (``starfm_main`` passes the
        output of ``np.meshgrid``).
    rawindexshape : tuple
        Shape of each output tensor, (1, bands, rows, cols).

    Returns
    -------
    list of torch.Tensor
        One new float32 tensor per input, in the same order.
    """
    cpu = torch.device("cpu")
    band_count = arraylist[0].shape[1]

    def _gather(src):
        out = torch.zeros(rawindexshape, dtype=torch.float32).to(cpu)
        for b in range(band_count):
            out[0, b] = src[0, b][indexarray]
        return out

    return [_gather(src) for src in arraylist]
def similar_filter(datalist, sital, sitam):
    """Per-pixel thresholds used to filter candidate similar pixels.

    Parameters
    ----------
    datalist : [l1, m1, m2]
        Three tensors of shape (1, bands, H, W).
    sital, sitam : float
        Tolerance terms added to the thresholds.

    Returns
    -------
    tuple
        ``(spectral_max, temporal_max)``, each of shape (1, H * W):

        - ``spectral_max = max_over_bands |l1 - m1| + sqrt(sital**2 + sitam**2)``
        - ``temporal_max = max_over_bands |m2 - m1| + sqrt(sitam**2 + sitam**2)``
    """
    l1, m1, m2 = datalist
    spectral_tolerance = (sital ** 2 + sitam ** 2) ** 0.5
    temporal_tolerance = (sitam ** 2 + sitam ** 2) ** 0.5

    def _band_max(diff):
        # A (1, 1) unfold flattens to (1, bands, H*W); take the max over bands.
        return nn.functional.unfold(diff, (1, 1)).max(1)[0]

    spectral_max = _band_max(abs(l1 - m1)) + spectral_tolerance
    temporal_max = _band_max(abs(m2 - m1)) + temporal_tolerance
    return (spectral_max, temporal_max)
###starfm for onepart##########################################################
def starfm_onepart(datalist, similar, thresholdmax, window, outshape, dist):
    """Run the STARFM prediction for one padded image part.

    The prediction for each output pixel is a weighted sum over its window
    of ``l1 + m2 - m1``. The weights are the inverse of (spectral weight *
    temporal weight * distance), restricted to filtered similar pixels and
    normalised to sum to 1.

    Parameters
    ----------
    datalist : [l1, m1, m2]
        Padded tensors of shape (1, bands, h, w). ``l1`` is the image being
        updated; ``m1`` and ``m2`` are the images at the two dates.
        ``starfm_main`` passes (s1, l1, l2).
    similar : torch.Tensor
        Initial similar-pixel mask, (1, window_elems, n_pixels).
    thresholdmax : pair of tensors
        Per-pixel filter thresholds from ``similar_filter``.
    window : (int, int)
        ``unfold`` kernel size.
    outshape : (int, int)
        Output size of the part. Its product must equal the number of window
        positions in the padded input.
    dist : torch.Tensor
        Distance weights, (1, window_elems, 1).

    Returns
    -------
    torch.Tensor
        Prediction of shape (1, bands, outshape[0], outshape[1]).
    """
    base, coarse_t1, coarse_t2 = datalist
    n_bands = base.shape[1]
    n_pixels = outshape[0] * outshape[1]
    cpu = torch.device("cpu")

    def _to_columns(img):
        # (1, bands, h, w) -> (bands, window_elems, n_pixels)
        return nn.functional.unfold(img, window).view(n_bands, -1, n_pixels)

    base, coarse_t1, coarse_t2 = (_to_columns(x) for x in (base, coarse_t1, coarse_t2))
    spectral_diff = abs(base - coarse_t1)
    temporal_diff = abs(coarse_t2 - coarse_t1)

    # Inverse of (spectral * temporal * distance) weights.
    weights = 1 / (caculate_weight(spectral_diff, temporal_diff) * dist)

    # Keep only candidates that pass the per-band filter AND the initial mask.
    candidates = torch.zeros(base.shape, dtype=torch.float32).to(cpu)
    keep = weight_bythreshold_allbands(candidates, spectral_diff, temporal_diff, thresholdmax)
    weights = weights * keep * similar
    weights = weights / weights.sum(1, keepdim=True)

    # Weighted prediction, then columns back to an image.
    fused = ((base + coarse_t2 - coarse_t1) * weights).sum(1)
    fused = fused.reshape(1, n_bands, fused.shape[1])
    return nn.functional.fold(fused.view(1, -1, n_pixels), outshape, (1, 1))
###starfm for allpart#########################################################
def starfm_main(s1r, l1r, l2r, param=None):
    """Predict the fine-resolution image at t2 with STARFM, part by part.

    Parameters
    ----------
    s1r : torch.Tensor
        Fine-resolution image at t1, shape (1, bands, H, W).
    l1r, l2r : torch.Tensor
        Coarse images at t1 and t2, on the same (1, bands, H, W) grid as
        ``s1r``.
    param : dict, optional
        Algorithm settings. Missing keys (or ``param=None``) use the
        defaults:

        - ``part_shape`` (140, 140): size of the tiles processed at once.
        - ``window_size`` (31, 31): moving-window size; must be square.
        - ``clusters`` 5: divisor of the initial similarity threshold.
        - ``NIRindex`` 3, ``redindex`` 2: bands used to find similar pixels.
        - ``sital``, ``sitam`` 0.001: tolerance terms used by
          ``similar_filter``.

    Returns
    -------
    torch.Tensor
        Predicted image at t2, shape (1, bands, H, W).

    Raises
    ------
    ValueError
        If ``window_size`` is not square.
    """
    defaults = {'part_shape': (140, 140),
                'window_size': (31, 31),
                'clusters': 5,
                'NIRindex': 3, 'redindex': 2,
                'sital': 0.001, 'sitam': 0.001}
    # A None default avoids a shared mutable dict. Merging lets callers
    # override only the keys they need.
    settings = {**defaults, **(param or {})}

    time_start = time.time()
    device = torch.device("cpu")
    parts_shape = settings['part_shape']
    window = settings['window_size']
    clusters = settings['clusters']
    NIRindex = settings['NIRindex']
    redindex = settings['redindex']
    sital = settings['sital']
    sitam = settings['sitam']
    # The padded patches gathered below come out transposed (see the
    # meshgrid note), so a non-square kernel would not line up with the part
    # shape.
    if window[0] != window[1]:
        raise ValueError('window_size must be square, got {}'.format(window))

    # Initial similar-pixel thresholds from the NIR and red bands of s1.
    threshold = spectral_similar_threshold(clusters, s1r[:, NIRindex:NIRindex + 1], s1r[:, redindex:redindex + 1])
    print('similar threshold (NIR,red)', threshold)

    imageshape = (s1r.shape[1], s1r.shape[2], s1r.shape[3])  # (bands, H, W)
    print('datashape:', imageshape)
    # Ceiling division: enough parts to cover the image, never an empty one.
    # The old ``h // p + 1`` created an empty part (and crashed) when p == 1.
    row = -(-imageshape[1] // parts_shape[0])
    col = -(-imageshape[2] // parts_shape[1])
    padrow = window[0] // 2
    padcol = window[1] // 2
    # Padding constants: STARFM uses inverse weights (1/w), so non-zero
    # constants help avoid 0 and NaN (1/0) at the borders.
    constant1 = 10
    constant2 = 20
    constant3 = 30
    # F.pad takes (left, right, top, bottom): column padding first.
    pad = (padcol, padcol, padrow, padrow)
    s1 = torch.nn.functional.pad(s1r, pad, 'constant', constant1)
    l1 = torch.nn.functional.pad(l1r, pad, 'constant', constant2)
    l2 = torch.nn.functional.pad(l2r, pad, 'constant', constant3)

    # Split row / column indices into contiguous chunks. Each
    # (row chunk, col chunk) pair is one part, processed independently.
    row_part = np.array_split(np.arange(imageshape[1]), row, axis=0)
    col_part = np.array_split(np.arange(imageshape[2]), col, axis=0)
    print('Split into {} parts,row number: {},col number: {}'.format(
        len(row_part) * len(col_part), len(row_part), len(col_part)))
    dist = nn.functional.unfold(
        torch.tensor(indexdistance(window), dtype=torch.float32).reshape(1, 1, window[0], window[1]),
        window).to(device)

    row_blocks = []
    for rnumber, row_index in enumerate(row_part):
        col_blocks = []
        for cnumber, col_index in enumerate(col_part):
            print('now for part{}'.format((rnumber, cnumber)))
            # Output pixels of this part. np.meshgrid's default 'xy' indexing
            # gives (cols, rows)-shaped index arrays, so every part is
            # computed transposed and flipped back at the end.
            rawindex = np.meshgrid(row_index, col_index)
            rawindexshape = (col_index.shape[0], row_index.shape[0])
            # The same part grown by the window margin, for the padded data.
            row_pad = np.arange(row_index[0], row_index[-1] + window[0])
            col_pad = np.arange(col_index[0], col_index[-1] + window[1])
            padindex = np.meshgrid(row_pad, col_pad)
            padindexshape = (col_pad.shape[0], row_pad.shape[0])
            # Initial similar pixels: similar in both NIR and red.
            NIR_similar = caculate_similar(
                s1[0, NIRindex][padindex].view(1, 1, padindexshape[0], padindexshape[1]),
                threshold[0], window)
            red_similar = caculate_similar(
                s1[0, redindex][padindex].view(1, 1, padindexshape[0], padindexshape[1]),
                threshold[1], window)
            similar = NIR_similar * red_similar
            # Per-pixel thresholds for filtering the candidates.
            thresholdmax = similar_filter(
                allband_arrayindex([s1r, l1r, l2r], rawindex,
                                   (1, imageshape[0], rawindexshape[0], rawindexshape[1])),
                sital, sitam)
            col_blocks.append(starfm_onepart(
                allband_arrayindex([s1, l1, l2], padindex,
                                   (1, imageshape[0], padindexshape[0], padindexshape[1])),
                similar, thresholdmax, window, rawindexshape, dist))
        # Parts are (1, bands, cols, rows): join column parts along dim 2.
        row_blocks.append(torch.cat(col_blocks, 2))
    # Join the row strips along dim 3, then swap back to (1, bands, H, W).
    l2_fake = torch.cat(row_blocks, 3).transpose(3, 2)

    time_end = time.time()
    print('now over,use time {:.4f}'.format(time_end - time_start))
    return l2_fake
def test():
    """Demo: fuse a Landsat/Sentinel-2 pair and compare with ground truth.

    Reads hard-coded local GeoTIFFs (paths below) via ``imgread`` from the
    project's ``utils`` and runs ``starfm_main``. It then shows the inputs
    and the prediction, writes ``s2_fake.tif``, and prints PSNR / SSIM
    scores.
    """
    # Sample data; per the original author's note, the band roles are not
    # verified and 'NIR' / 'red' below are only examples.
    l1file = 'E:\\TRA\\lake\\l8_contrast\\LC08_123039_20201022.tif'
    l2file = 'E:\\TRA\\lake\\l8_contrast\\LC08_123039_20201225.tif'
    s1file = 'E:\\TRA\\lake\\l8_contrast\\20201026T025821_20201026T025817_T50RKU.tif'
    s2file = 'E:\\TRA\\lake\\l8_contrast\\20201225T030131_20201225T030129_T50RKU.tif'
    param = {'part_shape': (75, 75),
             'window_size': (31, 31),
             'clusters': 5,
             'NIRindex': 1, 'redindex': 0,
             'sital': 0.001, 'sitam': 0.001}
    # Read images as numpy arrays; (bands, H, W) is assumed from the band
    # slicing and transposes below.
    s1 = imgread(s1file)[:4, :, :]
    l1 = imgread(l1file)
    l2 = imgread(l2file)
    s2_ground_truth = imgread(s2file)
    # Resample the coarse images onto the fine grid.
    l1 = resize(l1, s1.shape, order=1)
    l2 = resize(l2, s1.shape, order=1)
    # numpy -> tensor with a leading batch axis: (1, bands, H, W).
    # l1 / l2 already have s1's shape, so no second resize is needed.
    device = torch.device("cpu")
    s1_resize = torch.tensor(s1[np.newaxis], dtype=torch.float32).to(device)
    l1_resize = torch.tensor(l1[np.newaxis], dtype=torch.float32).to(device)
    l2_resize = torch.tensor(l2[np.newaxis], dtype=torch.float32).to(device)

    # Predict (tensor in -> tensor out).
    s2_fake = starfm_main(s1_resize, l1_resize, l2_resize, param)
    print(s2_fake.shape)
    s2_fake = s2_fake[0].cpu().numpy()

    # (channel, H, W) -> (H, W, channel) for plotting and metrics.
    s2_fake = s2_fake.transpose(1, 2, 0)
    s2_ground_truth = s2_ground_truth.transpose(1, 2, 0)
    s1 = s1.transpose(1, 2, 0)
    l1 = l1.transpose(1, 2, 0)
    l2 = l2.transpose(1, 2, 0)
    for title, image in (('landsat:t1', l1),
                         ('landsat:t2', l2),
                         ('sentinel:t1', s1),
                         ('sentinel:t2_fake', s2_fake),
                         ('sentinel:t2_groundtrue', s2_ground_truth)):
        plt.figure(title)
        plt.imshow(image)
    plt.show()

    # Evaluation. The metrics assume data in [0, 1] (data_range=1).
    writetif(s2_fake, 's2_fake.tif', 'E:\\TRA\\lake\\s2\\20200315T025539_20200315T030729_T50RKU.tif')
    psnr = 10. * np.log10(1. / np.mean((s2_fake - s2_ground_truth) ** 2))
    # channel_axis replaces the deprecated/removed `multichannel` argument.
    ssim1 = sm.structural_similarity(s2_fake, s2_ground_truth, data_range=1, channel_axis=-1)
    ssim2 = sm.structural_similarity(s1, s2_ground_truth, data_range=1, channel_axis=-1)
    # Baseline without similar-pixel weighting: s1 plus the coarse change.
    ssim3 = sm.structural_similarity(s1 + l2 - l1, s2_ground_truth, data_range=1, channel_axis=-1)
    print('psnr: {:.4f}'.format(psnr))
    print('with-similarpixels ssim: {:.4f};sentinel_t1 ssim: {:.4f};non-similarpixels ssim: {:.4f}'.format(ssim1, ssim2, ssim3))
    return
def writetif(dataset, target_file, reference_file, data_type=None):
    """Write an image array to a GeoTIFF georeferenced like another raster.

    Parameters
    ----------
    dataset : numpy.ndarray
        Image of shape (H, W, bands). A 2-D (H, W) array is written as a
        single band.
    target_file : str
        Output GeoTIFF path.
    reference_file : str
        Existing raster whose projection and geotransform are copied.
    data_type : int, optional
        GDAL data type for the output (e.g. ``gdal.GDT_Float32``). Defaults
        to the reference's first-band type, as before. Pass a float type when
        writing float predictions next to an integer reference; otherwise
        GDAL casts the values to the integer type.

    Raises
    ------
    FileNotFoundError
        If ``reference_file`` cannot be opened (``gdal.Open`` returned None).
    """
    reference = gdal.Open(reference_file, gdalconst.GA_ReadOnly)
    if reference is None:
        raise FileNotFoundError('cannot open reference raster: {}'.format(reference_file))
    if dataset.ndim == 2:
        # Treat a plain 2-D array as a single-band image.
        dataset = dataset[:, :, np.newaxis]
    height, width, band_count = dataset.shape
    if data_type is None:
        data_type = reference.GetRasterBand(1).DataType
    target = gdal.GetDriverByName("GTiff").Create(target_file, xsize=width, ysize=height,
                                                  bands=band_count, eType=data_type)
    target.SetProjection(reference.GetProjection())  # copy the projection
    target.SetGeoTransform(list(reference.GetGeoTransform()))  # copy the geotransform
    for index in range(band_count):
        out_band = target.GetRasterBand(index + 1)  # GDAL bands are 1-based
        out_band.WriteArray(dataset[:, :, index])
        out_band.FlushCache()
        # Compute and store the band statistics (ComputeBandStats only
        # returned mean/stddev without saving them).
        out_band.ComputeStatistics(False)
    # Dereferencing the GDAL datasets closes them and flushes the output to
    # disk; `del dataset` only removed the local array name.
    target.FlushCache()
    target = None
    reference = None
    print("正在写入完成")
if __name__ == "__main__":
    # Executing the module as a script runs the demo defined in test().
    test()
更多推荐
已为社区贡献3条内容
所有评论(0)