From d1f87fce9f62846c0e4ad89efe0e6e6789c6bff1 Mon Sep 17 00:00:00 2001 From: Kai Ruan Date: Thu, 20 Nov 2025 14:19:05 +0800 Subject: [PATCH] fix bugs --- psrn/__init__.py | 2 +- psrn/model/operators.py | 2 -- psrn/model/regressor.py | 15 +------- .../GP/model/expr_utils/calculator.py | 1 - .../token_generator/GP/model/ga/agent.py | 1 - psrn/model/token_generator/GP/model/ga/ga.py | 6 ++-- psrn/model/token_generator/gp.py | 2 -- psrn/utils/data.py | 36 ------------------- psrn/utils/evaluate.py | 8 ++--- psrn/utils/exprutils.py | 2 +- psrn/utils/gen_dr_mask.py | 6 +--- test/run_custom_data.py | 1 - 12 files changed, 10 insertions(+), 72 deletions(-) diff --git a/psrn/__init__.py b/psrn/__init__.py index b59281a..64abc40 100644 --- a/psrn/__init__.py +++ b/psrn/__init__.py @@ -1,3 +1,3 @@ from .model.regressor import PSRN_Regressor -from .cli import main as cli_main # 导出 main 函数 +from .cli import main as cli_main __version__ = "0.1.1-beta.3" \ No newline at end of file diff --git a/psrn/model/operators.py b/psrn/model/operators.py index 699f7e5..86c36a0 100644 --- a/psrn/model/operators.py +++ b/psrn/model/operators.py @@ -1,6 +1,4 @@ import torch -import torch.nn as nn -import torch.nn.functional as F # unary operators diff --git a/psrn/model/regressor.py b/psrn/model/regressor.py index ad407ab..eba833c 100644 --- a/psrn/model/regressor.py +++ b/psrn/model/regressor.py @@ -1,12 +1,10 @@ import math -import itertools import re import os import torch import torch.nn as nn -import torch.nn.functional as F import time import gc @@ -549,41 +547,33 @@ def __init__( self.use_dr_mask = use_dr_mask - # 确定 Mask 的文件名 (统一格式: layers_inputs_[Op1_Op2]_mask.npy) ops_str = "_".join(self.operators) file_name_mask = f'{self.n_symbol_layers}_{self.n_inputs}_[{ops_str}]_mask.npy' if self.use_dr_mask: - # 3.1 确定目录策略 - # 策略 A: 如果用户指定了目录,就用用户的 if dr_mask_dir is not None: self.dr_mask_dir = dr_mask_dir else: - # 策略 B: 尝试使用包内的 dr_mask 目录 self.dr_mask_dir = os.path.join(package_root, 'dr_mask') self.dr_mask_path = os.path.join(self.dr_mask_dir, file_name_mask) - # 3.2 检查文件是否存在,不存在则生成 if not os.path.exists(self.dr_mask_path): print(f"[Info] DR Mask not found at: {self.dr_mask_path}") print("Generating mask automatically...") try: - # 尝试在当前指定的目录下生成 (如果是 site-packages 可能会失败) generate_dr_mask_core( n_symbol_layers=self.n_symbol_layers, n_inputs=self.n_inputs, - ops=self.operators, # 直接传 list + ops=self.operators, save_dir=self.dr_mask_dir, device="cuda" if torch.cuda.is_available() else "cpu" ) except (PermissionError, OSError): - # 3.3 权限不足的回退策略 print(f"[Warning] No permission to write to {self.dr_mask_dir} (likely installed in system path).") print("[Info] Fallback: Generating mask in current working directory './dr_mask'") - # 更改目录到当前工作目录 self.dr_mask_dir = os.path.abspath("./dr_mask") self.dr_mask_path = os.path.join(self.dr_mask_dir, file_name_mask) @@ -597,7 +587,6 @@ def __init__( print("Generation finished.") - # 3.4 加载 Mask print(f"Loading drmask from {self.dr_mask_path}") try: dr_mask_np = np.load(self.dr_mask_path) @@ -729,8 +718,6 @@ def fit( ) ) # self.drm - print("len(self.triu_ls):") - print(len(self.triu_ls)) print("=" * 40) print("num of samples:", len(X)) diff --git a/psrn/model/token_generator/GP/model/expr_utils/calculator.py b/psrn/model/token_generator/GP/model/expr_utils/calculator.py index 657cc71..b5b283c 100644 --- a/psrn/model/token_generator/GP/model/expr_utils/calculator.py +++ b/psrn/model/token_generator/GP/model/expr_utils/calculator.py @@ -9,7 +9,6 @@ from ..config import Config from ..expr_utils.utils import time_limit, FinishException -from numpy import inf, seterr from symengine import sympify as se_sympify diff --git a/psrn/model/token_generator/GP/model/ga/agent.py b/psrn/model/token_generator/GP/model/ga/agent.py index 5ed775b..3134bc8 100644 --- a/psrn/model/token_generator/GP/model/ga/agent.py +++ b/psrn/model/token_generator/GP/model/ga/agent.py @@ -9,7 +9,6 @@ from numba import jit import numpy as np -from functools import lru_cache from deap import gp import array diff --git a/psrn/model/token_generator/GP/model/ga/ga.py b/psrn/model/token_generator/GP/model/ga/ga.py index f0f89f1..2dc7ad7 100644 --- a/psrn/model/token_generator/GP/model/ga/ga.py +++ b/psrn/model/token_generator/GP/model/ga/ga.py @@ -32,7 +32,8 @@ def __init__(self, config_s: Config): Initializing the tokens of genetic algorithm """ for num, exp in self.exp_dict.items(): - if not isinstance(exp, Expression): continue + if not isinstance(exp, Expression): + continue if exp.child == 0: pset.renameArguments(**{f'ARG{var_count}': f"exp{num}"}) var_count += 1 @@ -70,7 +71,8 @@ def ga_play(self, pop_init: List[List[int]]) -> List[List[int]]: hof = tools.HallOfFame(20) pops = pop_init pop = self.config_s.gp.pops - if len(pops) >= pop // 2: pops = random.sample(pops, pop // 2) + if len(pops) >= pop // 2: + pops = random.sample(pops, pop // 2) pops = [creator.Individual(tokens_to_deap(p, self.pset)) for p in pops] pops += self.toolbox.population(n=pop - len(pops)) _ = algorithms.eaSimple(pops, self.toolbox, self.config_s.gp.cxpb, self.config_s.gp.mutpb, diff --git a/psrn/model/token_generator/gp.py b/psrn/model/token_generator/gp.py index 36167c0..9920c2b 100644 --- a/psrn/model/token_generator/gp.py +++ b/psrn/model/token_generator/gp.py @@ -7,8 +7,6 @@ from .GP.model.config import Config from .GP.model.pipeline import Pipeline import sympy as sp -import random -import itertools from collections import Counter from ...utils.exprutils import has_nested_func MAX_LEN_SET = 1000 diff --git a/psrn/utils/data.py b/psrn/utils/data.py index e38c93e..1dd67af 100644 --- a/psrn/utils/data.py +++ b/psrn/utils/data.py @@ -87,42 +87,6 @@ def generate_X(ranges, down_sample, distrib="U"): return points -def get_dynamic_data(dataset_name, file_name): - """ - return dataset df, variables name and target name - - Example - ======= - - >>> df, variables_name, target_name = get_dynamic_data('ball','Baseball_train') - >>> variables_name - >>> ['t'] - >>> target_name - >>> 'h' - """ - df = pd.read_csv("./data/" + dataset_name + "/" + file_name + ".csv", header=None) - # NOTE: If use your own dataset, the column name cannot be `C` or `B`, - # because it's used as constant symbol in regressor - # And none of the variables can be capitalized, because there is a Lower case in eval - if dataset_name == "custom": - names = ["x", "y"] - target_name = "y" - elif dataset_name == "emps": - names = ["q", "qdot", "qddot", "tau"] - target_name = "qddot" - elif dataset_name == "roughpipe": - names = ["l", "y", "k"] - target_name = "y" - else: - raise ValueError("dataset_name error") - - df.columns = names - variables_name = names.copy() - variables_name.remove(target_name) - - return df, variables_name, target_name - - def expr_to_Y_pred(expr_sympy, X, variables): functions = { "sin": np.sin, diff --git a/psrn/utils/evaluate.py b/psrn/utils/evaluate.py index 84a7bcc..49c6a65 100644 --- a/psrn/utils/evaluate.py +++ b/psrn/utils/evaluate.py @@ -1,12 +1,8 @@ import numpy as np - import sympy import math -from .data import get_dynamic_data, expr_to_Y_pred - - -from .exprutils import time_limit, TimeoutException +from .exprutils import time_limit def get_sympy_complexity(expr_str): complexity_dict = { @@ -37,7 +33,7 @@ def get_sympy_complexity(expr_str): complexity = eval(ops_visual_str, complexity_dict) return complexity except Exception as e: - + print('ERR in get_sympy_complexity:', e) return 1e99 diff --git a/psrn/utils/exprutils.py b/psrn/utils/exprutils.py index 2464b64..8c2ba5c 100644 --- a/psrn/utils/exprutils.py +++ b/psrn/utils/exprutils.py @@ -1,7 +1,7 @@ from contextlib import contextmanager import threading import _thread -from sympy import Symbol, sin, cos, exp, log, count_ops +from sympy import sin, cos, exp, log import sympy as sp diff --git a/psrn/utils/gen_dr_mask.py b/psrn/utils/gen_dr_mask.py index b5936e3..2fed333 100644 --- a/psrn/utils/gen_dr_mask.py +++ b/psrn/utils/gen_dr_mask.py @@ -4,7 +4,6 @@ import sympy from tqdm import tqdm import click -import ast try: from ..model.models import PSRN @@ -28,10 +27,7 @@ def generate_dr_mask_core(n_symbol_layers, n_inputs, ops, save_dir="./dr_mask", elif ops == "koza_sign": ops = ["Add", "Mul", "Sub", "Div", "Identity", "Sign"] else: - try: - ops = ast.literal_eval(ops) - except: - ops = eval(ops) + ops = eval(ops) if not isinstance(ops, list): raise ValueError(f"Ops must be a list, got {type(ops)}: {ops}") diff --git a/test/run_custom_data.py b/test/run_custom_data.py index 7c486c9..f50d9c7 100644 --- a/test/run_custom_data.py +++ b/test/run_custom_data.py @@ -1,7 +1,6 @@ import os import click import time -import numpy as np import sympy as sp import pandas as pd