Reference code (slightly modified): https://github.com/shx951104/remote-sensing-images-fusion/blob/b4b4147c7896468516bd84c544a98270cd26589b/starfm_torch.py

Input:
a high-resolution image s1 and a low-resolution image l1 acquired on the same date (t1)
the low-resolution image l2 at time t2
Output: the predicted high-resolution image s2 at time t2
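The core entry point is starfm_main(s1, l1, l2, param): all three inputs are float32 torch tensors shaped (1, bands, H, W) on the same pixel grid (the low-resolution images are resampled to the high-resolution size first), and the return value is the predicted high-resolution tensor at t2. A minimal call sketch under those assumptions (the file paths are placeholders; imgread comes from the accompanying utils module, as in test() below):

import torch
from skimage.transform import resize
from utils import imgread   # helper used by the script below

s1 = imgread('s1_t1.tif')                              # high-res at t1, (bands, H, W); placeholder path
l1 = resize(imgread('l1_t1.tif'), s1.shape, order=1)   # low-res at t1, resampled to s1's grid
l2 = resize(imgread('l2_t2.tif'), s1.shape, order=1)   # low-res at t2
to_tensor = lambda a: torch.tensor(a[None], dtype=torch.float32)    # -> (1, bands, H, W)
s2_fake = starfm_main(to_tensor(s1), to_tensor(l1), to_tensor(l2))  # predicted high-res at t2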

# -*- coding: utf-8 -*-
"""
Created on Tue Mar 17 15:15:36 2020

@author: Administrator
"""


import numpy as np
import torch
import torch.nn as nn
import time
#import skimage.measure as  sm
import skimage.metrics  as  sm
import cv2
from osgeo import gdal,gdalconst
import matplotlib.pyplot as plt
import skimage.io as io
from skimage.transform import resize
from utils import *

###weight calculation tools##################################################
def weight_caculate(data):
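    # log-scaled absolute difference; larger spectral/temporal differences give a larger
    # value here, hence a smaller weight once the product is inverted (1/w) later on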
    return  torch.log((abs(data)*10000+1.00001))

def caculate_weight(s1l1,l1l2):
    #atmos difference
    ws1l1=weight_caculate(s1l1 )
    #time difference
    wl1l2=weight_caculate(l1l2 )
    return  ws1l1*wl1l2

###spatial distance calculation tool##########################################
def indexdistance(window):
    #one window, one distance weight matrix
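    # relative spatial distance of every window position to the window center:
    # dist = 1 + d / half_window, so the center gets value 1 and the corners the largest value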
    [distx,disty]=np.meshgrid(np.arange(window[0]),np.arange(window[1]))
    centerlocx,centerlocy=(window[0]-1)//2,(window[1]-1)//2
    dist=1+(((distx-centerlocx)**2+(disty-centerlocy)**2)**0.5)/((window[0]-1)//2)
    return  dist

###threshold select tool######################################################
def weight_bythreshold(weight,data,threshold):
    #make weight tensor
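    # flag pixels whose absolute difference is within the threshold as candidate similar pixels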
    weight[data<=threshold]=1
    return  weight
def weight_bythreshold_allbands(weight,l1m1,m1m2,thresholdmax):
    #make weight tensor
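    # mark window pixels that fall within either threshold, then keep only the
    # locations where every band ends up marked (allweight == 1), otherwise set 0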
    weight[l1m1<=thresholdmax[0]]=1
    weight[m1m2<=thresholdmax[1]]=1
    allweight=(weight.sum(0).view(1,weight.shape[1],weight.shape[2]))/weight.shape[0]
    allweight[allweight!=1]=0
    return  allweight


###initial similar pixels tools################################################
def spectral_similar_threshold(clusters,NIR,red):
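    # similar-pixel threshold per band: 2 * standard deviation / number of classes (as in STARFM)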
    thresholdNIR=NIR.std()*2/clusters
    thresholdred=red.std()*2/clusters
    return  (thresholdNIR,thresholdred)  

def caculate_similar(l1,threshold,window):
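    # unfold l1 into sliding windows and flag the pixels whose absolute difference
    # from the window's center pixel is within the threshold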
    #read l1
    device= torch.device( "cpu")
    l1=nn.functional.unfold(l1,window)    
    #calculate similarity
    weight=torch.zeros(l1.shape,dtype=torch.float32).to(device)  
    centerloc=( l1.size()[1]-1)//2
    weight=weight_bythreshold(weight,abs(l1-l1[:,centerloc:centerloc+1,:]) ,threshold)
    return weight

def classifier(l1):
    '''not used'''
    return

###similar pixels filter tools#################################################
def allband_arrayindex(arraylist,indexarray,rawindexshape):
    device= torch.device( "cpu")
    shape=arraylist[0].shape
    datalist=[]
    for array in arraylist:
        newarray=torch.zeros(rawindexshape,dtype=torch.float32).to(device)
        for band in range(shape[1]):
            newarray[0,band]=array[0,band][indexarray]
        datalist.append(newarray)
    return  datalist

def similar_filter(datalist,sital,sitam):
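    # per-pixel filtering thresholds: the maximum band difference at each pixel plus a
    # combined uncertainty term (sqrt(sital^2 + sitam^2) for l1-m1, sqrt(2)*sitam for m1-m2)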
    [l1,m1,m2]=datalist
    l1m1=abs(l1-m1)
    m1m2=abs(m2-m1)
    #####
    l1m1=nn.functional.unfold(l1m1,(1,1)).max(1)[0]+(sital**2+sitam**2)**0.5
    m1m2=nn.functional.unfold(m1m2,(1,1)).max(1)[0]+(sitam**2+sitam**2)**0.5
    return (l1m1,m1m2)

###starfm for onepart##########################################################
def starfm_onepart(datalist,similar,thresholdmax,window,outshape,dist):
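    # STARFM prediction for one tile:
    #   1) unfold l1, m1, m2 into sliding windows (im2col)
    #   2) combine the spectral/temporal log weights with the spatial distance weight
    #   3) mask out non-similar pixels and normalize the weights within each window
    #   4) the weighted sum of (l1 + m2 - m1) gives the predicted fine-resolution pixel at t2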
    #####param and data 
    [l1,m1,m2]=datalist
    bandsize=l1.shape[1]
    outshape=outshape
    blocksize=outshape[0]*outshape[1]
    device= torch.device( "cpu")
    #####img to col
    l1=nn.functional.unfold(l1,window)
    m1=nn.functional.unfold(m1,window)
    m2=nn.functional.unfold(m2,window)
    l1=l1.view(bandsize,-1,blocksize)
    m1=m1.view(bandsize,-1,blocksize)
    m2=m2.view(bandsize,-1,blocksize)   
    l1m1=abs(l1-m1)
    m1m2=abs(m2-m1)
    #####calculate weights
    #temporal and spatial weights
    w=caculate_weight(l1m1,m1m2)
    w=1/(w*dist)
    #similar pixels: 1:by threshold 2:by classifier
    wmask=torch.zeros(l1.shape,dtype=torch.float32).to(device)  
    
    #filter similar pixels  for each band: (bandsize,windowsize,blocksize)
    #wmasknew=weight_bythreshold(wmask,l1m1,thresholdmax[0]) 
    #wmasknew=weight_bythreshold(wmasknew,m1m2,thresholdmax[1])    
    
    #filter similar pixels for all bands: (1,windowsize,blocksize)
    wmasknew=weight_bythreshold_allbands(wmask,l1m1,m1m2,thresholdmax) 
    #mask
    w=w*wmasknew*similar
    #normalize weights
    w=w/(w.sum(1).view(w.shape[0],1,w.shape[2]))
    #####predict and transform back
    #predict l2
    l2=(l1+m2-m1)*w
    l2=l2.sum(1).reshape(1,bandsize,l2.shape[2])
    #col to img
    l2=nn.functional.fold(l2.view(1,-1,blocksize),outshape,(1,1))
    return l2
###starfm for allpart#########################################################
def starfm_main(s1r,l1r,l2r,
                param={'part_shape':(140,140),
                       'window_size':(31,31),
                       'clusters':5,
                       'NIRindex':3,'redindex':2,
                       'sital':0.001,'sitam':0.001}):
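    # main flow: pad the inputs, split the scene into tiles of roughly part_shape,
    # run starfm_onepart on each window-padded tile and stitch the results back together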
    #get start time
    time_start=time.time()  
    device= torch.device( "cpu")
    #read parameters
    parts_shape=param['part_shape']
    window=param['window_size']
    clusters=param['clusters']
    NIRindex=param['NIRindex']
    redindex=param['redindex']
    sital=param['sital']
    sitam=param['sitam']
    #calculate the initial similar-pixel thresholds
    threshold=spectral_similar_threshold(clusters,s1r[:,NIRindex:NIRindex+1],s1r[:,redindex:redindex+1])    
    print('similar threshold (NIR,red)',threshold)
    ####shape
    imageshape=(s1r.shape[1],s1r.shape[2],s1r.shape[3])
    print('datashape:',imageshape)
    row=imageshape[1]//parts_shape[0]+1
    col=imageshape[2]//parts_shape[1]+1
    padrow=window[0]//2
    padcol=window[1]//2 
    #####pad with constants for the sliding window; STARFM uses inverse distance weights (1/w), so it is better to avoid 0 and NaN (1/0), or use another distance measure
    constant1=10
    constant2=20
    constant3=30
    s1=torch.nn.functional.pad( s1r,(padrow,padcol,padrow,padcol),'constant', constant1)
    l1=torch.nn.functional.pad( l1r,(padrow,padcol,padrow,padcol),'constant', constant2)
    l2=torch.nn.functional.pad( l2r,(padrow,padcol,padrow,padcol),'constant', constant3)
    #split parts , get index and  run for every part
    row_part=np.array_split( np.arange(imageshape[1]), row , axis = 0) #row_part/col_part hold the row/column indices covered by each tile
    col_part=np.array_split( np.arange(imageshape[2]),  col, axis = 0) 
    print('Split into {} parts,row number: {},col number: {}'.format(len(row_part)*len(col_part),len(row_part),len(col_part)))
    dist=nn.functional.unfold(torch.tensor(  indexdistance(window),dtype=torch.float32).reshape(1,1,window[0],window[1]),window).to(device)

    for rnumber,row_index in enumerate(row_part):
        for cnumber,col_index in enumerate(col_part):
            ####run for part: (rnumber,cnumber)
            print('now for part{}'.format((rnumber,cnumber)))
            ####output index
            rawindex=np.meshgrid(row_index,col_index)
            ####output shape
            rawindexshape=(col_index.shape[0],row_index.shape[0])
            ####the real parts_index ,for reading the padded data 
            row_pad=np.arange(row_index[0],row_index[-1]+window[0])
            col_pad=np.arange(col_index[0],col_index[-1]+window[1])    
            padindex=np.meshgrid(row_pad,col_pad)
            padindexshape=(col_pad.shape[0],row_pad.shape[0])
            ####calculate initial similar pixels
            NIR_similar=caculate_similar(s1[0,NIRindex][ padindex ].view(1,1,padindexshape[0],padindexshape[1]),threshold[0],window)   
            red_similar=caculate_similar(s1[0,redindex][ padindex ].view(1,1,padindexshape[0],padindexshape[1]),threshold[1],window)  
            similar=NIR_similar*red_similar      
            ####calculate the threshold used for the similar-pixel filter
            thresholdmax = similar_filter( allband_arrayindex([s1r,l1r,l2r],rawindex,(1,imageshape[0],rawindexshape[0],rawindexshape[1])),
                                        sital,sitam)
            ####Splicing each col at rnumber-th row
            if cnumber==0:
                rowdata=starfm_onepart( allband_arrayindex([s1,l1,l2],padindex,(1,imageshape[0],padindexshape[0],padindexshape[1])),
                                       similar,thresholdmax,window,rawindexshape,dist
                                       )  
                
            else:
                rowdata=torch.cat( (rowdata,
                                    starfm_onepart( allband_arrayindex([s1,l1,l2],padindex,(1,imageshape[0],padindexshape[0],padindexshape[1])),
                                                   similar,thresholdmax,window,rawindexshape,dist)  ) ,2) 
        ####Splicing each row        
        if rnumber==0:
            l2_fake=rowdata
        else:            
            l2_fake=torch.cat((l2_fake,rowdata),3)
   
    l2_fake=l2_fake.transpose(3,2)
    #time cost
    time_end=time.time()    
    print('now over,use time {:.4f}'.format(time_end-time_start))  
    return l2_fake


def test():
    ##three-band test data (the spectral response range of each band is unknown, so 'NIR' and 'red' are only examples)
    l1file='E:\\TRA\\lake\\l8_contrast\\LC08_123039_20201022.tif'
    l2file='E:\\TRA\\lake\\l8_contrast\\LC08_123039_20201225.tif'
    s1file='E:\\TRA\\lake\\l8_contrast\\20201026T025821_20201026T025817_T50RKU.tif'
    s2file='E:\\TRA\\lake\\l8_contrast\\20201225T030131_20201225T030129_T50RKU.tif'
    ##param
    param={'part_shape':(75,75),
           'window_size':(31,31),
           'clusters':5,
           'NIRindex':1,'redindex':0,
           'sital':0.001,'sitam':0.001}
    
    ##read images from files(numpy)
    s1=imgread(s1file)[:4,:,:]
    l1=imgread(l1file)
    l2=imgread(l2file)
    s2_ground_truth=imgread(s2file)    
    l1 = resize(l1, s1.shape,order=1)
    l2 = resize(l2, s1.shape,order=1)
    
    ##numpy to tensor
    shape=s1.shape
    s1_resize=torch.tensor(s1.reshape(1,shape[0],shape[1],shape[2]) ,dtype=torch.float32) #1,bands,h,w
    l1_resize=torch.tensor(resize(l1,(shape[0],shape[1],shape[2])).reshape(1,shape[0],shape[1],shape[2]) ,dtype=torch.float32)
    l2_resize=torch.tensor(resize(l2,(shape[0],shape[1],shape[2])).reshape(1,shape[0],shape[1],shape[2]) ,dtype=torch.float32)
    device= torch.device( "cpu")
    s1_resize=s1_resize.to(device)
    l1_resize=l1_resize.to(device)
    l2_resize=l2_resize.to(device)      
    
    ##predict (tensor input -> tensor output)
    s2_fake=starfm_main(s1_resize,l1_resize,l2_resize,param)
    print(s2_fake.shape)
    
    ##tensor to numpy
    if device.type=='cuda':
        s2_fake=s2_fake[0].cpu().numpy()
    else:
        s2_fake=s2_fake[0].numpy()    
    
    ##show results 
    #transform: (channel,H,W) to (H,W,channel)
    s2_fake=s2_fake.transpose(1,2,0)
    s2_ground_truth=s2_ground_truth.transpose(1,2,0)
    s1=s1.transpose(1,2,0)
    l1=l1.transpose(1,2,0)
    l2=l2.transpose(1,2,0)
    #plot
    plt.figure('landsat:t1')
    plt.imshow(l1) 
    plt.figure('landsat:t2')
    plt.imshow(l2) 
    plt.figure('sentinel:t1')
    plt.imshow(s1) 
    plt.figure('sentinel:t2_fake')
    plt.imshow(s2_fake)
    plt.figure('sentinel:t2_groundtrue')
    plt.imshow(s2_ground_truth)    
    plt.show()
    ##evaluation
    # driver = gdal.GetDriverByName("GTiff")
    # dataset = driver.Create('result.tif', im_width, im_height, im_bands, datatype)
    # writetif(l2_fake,'l2_fake.tif','')
    writetif(s2_fake,'s2_fake.tif','E:\\TRA\\lake\\s2\\20200315T025539_20200315T030729_T50RKU.tif')
    psnr = 10. * np.log10(1. / np.mean((s2_fake - s2_ground_truth) ** 2))
    ssim1=sm.structural_similarity(s2_fake,s2_ground_truth,data_range=1,multichannel=True)
    ssim2=sm.structural_similarity(s1,s2_ground_truth,data_range=1,multichannel=True)
    ssim3=sm.structural_similarity(s1+l2-l1,s2_ground_truth,data_range=1,multichannel=True)
    print('psnr: {:.4f}'.format(psnr))
    print('with-similarpixels ssim: {:.4f};sentinel_t1 ssim: {:.4f};non-similarpixels ssim: {:.4f}'.format(ssim1,ssim2,ssim3))
    return

def writetif(dataset,target_file,reference_file):
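    # write an (H, W, bands) array to a GeoTIFF, copying projection and geotransform from reference_file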
    reference = gdal.Open(reference_file,gdalconst.GA_ReadOnly)
    band_count = dataset.shape[2]  # number of bands
    target = gdal.GetDriverByName("GTiff").Create(target_file, xsize=dataset.shape[1], ysize=dataset.shape[0], bands=band_count,
                                       eType= reference.GetRasterBand(1).DataType)
    geotrans = list(reference.GetGeoTransform())
    target.SetProjection(reference.GetProjection())  # copy the projection from the reference image
    target.SetGeoTransform(geotrans)  # copy the geotransform from the reference image
    total = band_count + 1
    for index in range(1, total):
        # data = dataset.GetRasterBand(index).ReadAsArray(buf_xsize=dataset.shape[0], buf_ysize=dataset.shape[1])
        out_band = target.GetRasterBand(index)
        # out_band.SetNoDataValue(dataset.GetRasterBand(index).GetNoDataValue())
        out_band.WriteArray(dataset[:,:,index-1])  # write the band data into the new image
        out_band.FlushCache()
        out_band.ComputeBandStats(False)  # compute band statistics
    print("正在写入完成")
    del dataset 
if __name__ == "__main__":
    test()
    

    
