import os
import numpy as np
import torch
from torch import nn
import torch.optim as optim
from torch.distributions import Normal
from win32com.client.gencache import EnsureDispatch, EnsureModule
from win32com.client import CastTo, constants


class ZemaxEnv:
    """Gym-style environment that tunes merit-function weights of a Zemax lens file.

    The environment drives Zemax OpticStudio through the ZOS-API COM interface.
    Each ``step`` nudges the weights of selected merit-function (MFE) rows,
    runs one of three optimizers (local / global / Hammer), and rewards the
    agent for reducing the mean RMS spot size over three fields.

    State vector (float32), in order:
        * current weights of ``weight_rows``                  (``num_weights`` values)
        * mean RMS spot size of fields 1..3                   (3 values)
        * values of the TOTR, RSCE, MTFA_C, MTFA_E, MTFA_E2, EFFL rows (6 values)

    Action vector: ``num_weights`` weight deltas (each scaled by 0.1) followed
    by one optimizer selector. The selector is mapped through ``(a + 1) / 2``,
    so it is presumably expected to lie in ``[-1, 1]``.
    """

    # Default optimizer timeout passed to RunAndWaitWithTimeout; __init__
    # overrides it per instance.
    opt_timeout = 240
    # Order in which constraint-operand values appear in the state vector.
    _CONSTRAINT_KEYS = ('TOTR', 'RSCE', 'MTFA_C', 'MTFA_E', 'MTFA_E2', 'EFFL')
    # MFE rows whose unreadable value falls back to 0.0 instead of 999.9.
    # NOTE(review): row 9 is MTFA_E2, but the other MTF rows (5, 7) fall back to
    # 999.9, and row 8 is a weight row rather than a value row. Confirm intent.
    _ZERO_FALLBACK_ROWS = (8, 9)
    # RMS reported for a field when no wavelength produced a usable spot size.
    _FALLBACK_RMS = 100.0

    def __init__(self, zmx_file, target_rms=7.0, opt_timeout=240):
        """Start OpticStudio, load ``zmx_file`` and record the initial MFE weights.

        Args:
            zmx_file: Path of the lens file passed to ``LoadFile``.
            target_rms: The episode ends, with a bonus, once the mean RMS drops
                below this value.
            opt_timeout: Timeout handed to every optimizer run in ``step``.

        Raises:
            Exception: If the ZOS-API application cannot be created, the file
                fails to load, or the spot-diagram analysis cannot be opened.
                In the last two cases OpticStudio is shut down before the
                exception propagates.
        """
        EnsureModule('{EA433010-2BAC-43C4-857C-7AEAC4A8CCE0}', lcid=0, major=1, minor=0)
        EnsureModule('{F66684D7-AAFE-4A62-9156-FF7A7853F764}', lcid=0, major=1, minor=0)
        self.TheConnection = EnsureDispatch("ZOSAPI.ZOSAPI_Connection")
        self.TheApplication = self.TheConnection.CreateNewApplication()
        if self.TheApplication is None:
            raise Exception("Unable to acquire a ZOS-API application")

        self.spot_window = None
        self.spot_interface = None
        try:
            self.TheSystem = self.TheApplication.PrimarySystem
            if not self.TheSystem.LoadFile(zmx_file, False):
                raise Exception("文件加载失败")  # "failed to load file"
            self.TheLDE = self.TheSystem.LDE
            self.TheMFE = self.TheSystem.MFE
            self.TheSystem.UpdateMode = constants.LensUpdateMode_AllWindows
            # os.cpu_count() may return None; the optimizers need an integer.
            self.cpu_cores = os.cpu_count() or 1

            self.target_rms = target_rms
            self.opt_timeout = opt_timeout
            # MFE rows whose weights the agent adjusts.
            self.weight_rows = [2, 4, 6, 8, 10, 11, 20, 25]
            # MFE rows whose values are reported in the state vector.
            self.value_rows = {
                'TOTR': 1,
                'RSCE': 3,
                'MTFA_C': 5,
                'MTFA_E': 7,
                'MTFA_E2': 9,
                'EFFL': 11,
            }

            self.num_weights = len(self.weight_rows)
            self.action_dim = self.num_weights + 1  # weight deltas + optimizer selector
            self.state_dim = self.num_weights + 3 + len(self._CONSTRAINT_KEYS)

            self.CRASH_RMS_THRESHOLD = 1e-6
            self.CRASH_PENALTY = -100.0
            self.init_weights = {
                row_num: self.get_operand_weight(row_num) for row_num in self.weight_rows
            }
            self.RMS_WEIGHT = 10.0
            self.SUCCESS_BONUS = 50.0
            self._init_spot_window()
        except BaseException:
            # Do not leave an orphaned OpticStudio instance running.
            try:
                self.close()
            except Exception:
                pass
            raise

    def _init_spot_window(self):
        """Open a Standard Spot analysis and run it once.

        Raises:
            Exception: If the analysis cannot be created or run. The original
                error is chained as ``__cause__``.
        """
        try:
            self.spot_window = self.TheSystem.Analyses.New_Analysis(constants.AnalysisIDM_StandardSpot)
            self.spot_interface = CastTo(self.spot_window, target='IA_')
            self.spot_interface.ApplyAndWaitForCompletion()
        except Exception as err:
            raise Exception("光斑窗口初始化失败") from err  # "spot window init failed"

    def close(self):
        """Close the spot analysis (best effort) and shut down OpticStudio.

        A second call after a successful close is a no-op.
        """
        if self.spot_window is not None:
            try:
                self.spot_window.Close()
            except Exception:
                pass  # best effort: shutting the application down matters more
            self.spot_window = None
            self.spot_interface = None
        if self.TheApplication is not None:
            self.TheApplication.CloseApplication()
            self.TheApplication = None

    @staticmethod
    def _close_tool(*candidates):
        """Call ``Close()`` on the first candidate that accepts it; ignore failures."""
        for candidate in candidates:
            if candidate is None:
                continue
            try:
                candidate.Close()
                return
            except Exception:
                continue

    def _run_optimizer(self, open_tool, timeout, configure):
        """Open an optimizer tool, run it for at most ``timeout``, then close it.

        Args:
            open_tool: Zero-argument callable that returns the tool, or None.
            timeout: Value passed to ``RunAndWaitWithTimeout``.
            configure: Callable that sets tool-specific options before the run.

        Returns:
            True if the run sequence completed. False if the tool could not be
            opened or any ZOS-API call raised.
        """
        tool = None
        runner = None
        try:
            tool = open_tool()
            if tool is None:
                return False
            configure(tool)
            # The run/cancel/wait/close calls go through the ISystemTool cast.
            runner = CastTo(tool, target="ISystemTool")
            runner.RunAndWaitWithTimeout(timeout)
            # Stop the run in case it is still going after the timeout, then
            # wait for the stop to take effect.
            runner.Cancel()
            runner.WaitForCompletion()
            self.TheSystem.UpdateStatus()
            return True
        except Exception:
            return False
        finally:
            # Close exactly once, preferring the cast interface.
            # NOTE(review): an unclosed tool may block later Open*Optimization
            # calls; confirm against the ZOS-API docs.
            self._close_tool(runner, tool)

    def run_local_optimization(self, timeout=240):
        """Run damped-least-squares local optimization; return True on success."""
        def configure(opt):
            opt.Algorithm = constants.OptimizationAlgorithm_DampedLeastSquares
            opt.NumberOfCores = self.cpu_cores

        return self._run_optimizer(
            lambda: self.TheSystem.Tools.OpenLocalOptimization(), timeout, configure
        )

    def run_global_optimization(self, timeout=240):
        """Run global optimization (DLS, keep 10 solutions); return True on success."""
        def configure(opt):
            opt.Algorithm = constants.OptimizationAlgorithm_DampedLeastSquares
            opt.NumberOfCores = self.cpu_cores
            opt.NumberToSave = constants.OptimizationSaveCount_Save_10

        return self._run_optimizer(
            lambda: self.TheSystem.Tools.OpenGlobalOptimization(), timeout, configure
        )

    def run_hammer_optimization(self, timeout=240):
        """Run Hammer optimization; return True on success."""
        def configure(opt):
            opt.NumberOfCores = self.cpu_cores

        return self._run_optimizer(
            lambda: self.TheSystem.Tools.OpenHammerOptimization(), timeout, configure
        )

    def get_operand_weight(self, row_num):
        """Return the weight of MFE row ``row_num``, or 1.0 if it cannot be read."""
        try:
            return getattr(self.TheMFE.GetOperandAt(row_num), 'Weight', 1.0)
        except Exception:
            return 1.0

    def set_operand_weight(self, row_num, weight, min_weight=1.0, recalculate=True):
        """Write ``weight`` to MFE row ``row_num``.

        Args:
            row_num: Merit-function row to modify.
            weight: Requested weight.
            min_weight: Lower bound applied to ``weight``. ``None`` writes it
                unclamped.
            recalculate: Re-evaluate the merit function afterwards.

        Returns:
            True on success, False if any ZOS-API call failed.
        """
        try:
            row = self.TheMFE.GetOperandAt(row_num)
            row.Weight = weight if min_weight is None else max(min_weight, weight)
            if recalculate:
                self.TheMFE.CalculateMeritFunction()
                self.TheSystem.UpdateStatus()
            return True
        except Exception:
            return False

    def _recalculate_merit(self):
        """Re-evaluate the merit function; return False if ZOS-API raised."""
        try:
            self.TheMFE.CalculateMeritFunction()
            self.TheSystem.UpdateStatus()
            return True
        except Exception:
            return False

    def _fallback_value(self, row_num):
        """Placeholder for an unreadable operand: 0.0 for _ZERO_FALLBACK_ROWS, else 999.9."""
        return 0.0 if row_num in self._ZERO_FALLBACK_ROWS else 999.9

    def get_operand_value(self, row_num, recalculate=True):
        """Return the current value of MFE row ``row_num``.

        Args:
            row_num: Merit-function row to read.
            recalculate: Re-evaluate the merit function first. Pass False when
                the caller has just done so.

        Returns:
            The finite operand value. Returns ``_fallback_value(row_num)`` when
            the value is missing or NaN/inf, or when any ZOS-API call fails.
        """
        try:
            if recalculate:
                self.TheMFE.CalculateMeritFunction()
                self.TheSystem.UpdateStatus()
            val = getattr(self.TheMFE.GetOperandAt(row_num), 'Value', None)
            if val is not None and np.isfinite(val):
                return val
        except Exception:
            pass  # fall through to the fallback value
        return self._fallback_value(row_num)

    def reset(self):
        """Restore the MFE weights recorded at construction and return the state.

        Weights are written unclamped, so initial weights below 1.0 (e.g. 0)
        are restored exactly.
        NOTE(review): only weights are restored. Lens parameters changed by
        earlier optimizer runs are not reloaded from the file.
        """
        for row_num in self.weight_rows:
            self.set_operand_weight(
                row_num, self.init_weights[row_num], min_weight=None, recalculate=False
            )
        self.TheSystem.UpdateStatus()
        return self._get_state()

    def _get_state(self):
        """Assemble the observation: weights, three field RMS values, six operand values."""
        weights = [self.get_operand_weight(row_num) for row_num in self.weight_rows]
        rms_list = self._get_three_rms()
        # Evaluate the merit function once instead of once per operand. If
        # that fails, let each read retry the evaluation itself.
        merit_ok = self._recalculate_merit()
        constraints = [
            self.get_operand_value(self.value_rows[key], recalculate=not merit_ok)
            for key in self._CONSTRAINT_KEYS
        ]
        return np.array(weights + list(rms_list) + constraints, dtype=np.float32)

    def _get_three_rms(self):
        """Return the wavelength-averaged RMS spot size for fields 1..3.

        For each field, RMS values for wavelengths 1..3 that are not None and
        exceed ``CRASH_RMS_THRESHOLD`` are averaged. A field with no usable
        value, or any failure to run or read the analysis, yields
        ``_FALLBACK_RMS``.
        """
        try:
            self.spot_interface.ApplyAndWaitForCompletion()
            # Fetch the results once rather than once per (field, wavelength).
            spot_data = self.spot_interface.GetResults().SpotData
        except Exception:
            return np.array([self._FALLBACK_RMS] * 3)

        rms_values = []
        for field in range(1, 4):
            samples = []
            for wave in range(1, 4):
                try:
                    val = spot_data.GetRMSSpotSizeFor(field, wave)
                    if val is not None and val > self.CRASH_RMS_THRESHOLD:
                        samples.append(val)
                except Exception:
                    pass  # skip wavelengths the analysis could not report
            rms_values.append(np.mean(samples) if samples else self._FALLBACK_RMS)
        return np.array(rms_values)

    def step(self, action):
        """Apply ``action``, run one optimizer, and score the resulting lens.

        Args:
            action: Sequence or array holding ``num_weights`` weight deltas
                followed by the optimizer selector (the last element). Mapped
                selector values below 0.33 choose local, below 0.66 global,
                and anything else Hammer.

        Returns:
            ``(next_state, reward, done, info)``. ``info`` holds ``rms_mean``,
            ``is_crashed`` and ``opt_success``.

        Raises:
            ValueError: If ``action`` has fewer than ``num_weights + 1`` elements.
        """
        action = np.asarray(action, dtype=np.float64).reshape(-1)
        if action.size < self.num_weights + 1:
            raise ValueError(
                f"action needs at least {self.num_weights + 1} elements, got {action.size}"
            )

        weight_deltas = action[:self.num_weights] * 0.1
        for row_num, delta in zip(self.weight_rows, weight_deltas):
            # Hand COM a plain Python float; the bridge may not accept numpy
            # scalar types.
            new_weight = float(self.get_operand_weight(row_num) + delta)
            self.set_operand_weight(row_num, new_weight, recalculate=False)
        # One merit-function evaluation after all weight edits.
        self._recalculate_merit()

        selector = (action[-1] + 1) / 2
        if selector < 0.33:
            optimizer = self.run_local_optimization
        elif selector < 0.66:
            optimizer = self.run_global_optimization
        else:
            optimizer = self.run_hammer_optimization
        opt_success = optimizer(timeout=self.opt_timeout)

        next_state = self._get_state()
        rms_list = next_state[self.num_weights:self.num_weights + 3]
        print(rms_list)
        new_rms_mean = float(np.mean(rms_list))

        # NOTE(review): _get_three_rms discards values <= CRASH_RMS_THRESHOLD
        # and substitutes _FALLBACK_RMS for empty fields, so this check cannot
        # trigger with the RMS values it produces. Confirm how crashes should
        # be detected.
        is_crashed = new_rms_mean < self.CRASH_RMS_THRESHOLD
        reward = self.CRASH_PENALTY if is_crashed else 0.0
        reward -= new_rms_mean * self.RMS_WEIGHT
        done = new_rms_mean < self.target_rms
        if done:
            reward += self.SUCCESS_BONUS

        info = {
            "rms_mean": new_rms_mean,
            "is_crashed": is_crashed,
            "opt_success": opt_success,
        }
        return next_state, reward, done, info


class ActorCritic(nn.Module):
    """Actor network: a shared two-layer MLP trunk plus a linear policy head.

    ``fc_actor`` maps the 128-unit trunk features to ``action_dim`` outputs.
    ``log_std`` is a learnable, state-independent log standard deviation,
    initialised to zeros. It is presumably used to build a ``Normal`` policy
    distribution, but the forward pass is not part of this snippet.

    Args:
        state_dim: Size of the observation vector fed to the trunk.
        action_dim: Number of action components produced by the actor head.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc_shared = nn.Sequential(
            nn.Linear(state_dim, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=128),
            nn.ReLU(),
        )
        # out_features must be passed by keyword because in_features is; a
        # positional argument after a keyword argument is a SyntaxError.
        self.fc_actor = nn.Linear(in_features=128, out_features=action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim))
        # The source snippet is truncated here; a critic (value) head such as
        # the following usually comes next:
        # self.fc_critic = nn.Linear(in_features=128, out_features=1)

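# ---------------------------------------------------------------------------
# The lines below are residue from the web page this code was copied from.
# They are not part of the program and are kept only as comments.
#
# Logo
#
# AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
#
# 更多推荐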