Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion psrn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .model.regressor import PSRN_Regressor
from .cli import main as cli_main # 导出 main 函数
from .cli import main as cli_main
__version__ = "0.1.1-beta.3"
2 changes: 0 additions & 2 deletions psrn/model/operators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# unary operators

Expand Down
15 changes: 1 addition & 14 deletions psrn/model/regressor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import math
import itertools
import re

import os

import torch
import torch.nn as nn
import torch.nn.functional as F

import time
import gc
Expand Down Expand Up @@ -549,41 +547,33 @@ def __init__(

self.use_dr_mask = use_dr_mask

# 确定 Mask 的文件名 (统一格式: layers_inputs_[Op1_Op2]_mask.npy)
ops_str = "_".join(self.operators)
file_name_mask = f'{self.n_symbol_layers}_{self.n_inputs}_[{ops_str}]_mask.npy'

if self.use_dr_mask:
# 3.1 确定目录策略
# 策略 A: 如果用户指定了目录,就用用户的
if dr_mask_dir is not None:
self.dr_mask_dir = dr_mask_dir
else:
# 策略 B: 尝试使用包内的 dr_mask 目录
self.dr_mask_dir = os.path.join(package_root, 'dr_mask')

self.dr_mask_path = os.path.join(self.dr_mask_dir, file_name_mask)

# 3.2 检查文件是否存在,不存在则生成
if not os.path.exists(self.dr_mask_path):
print(f"[Info] DR Mask not found at: {self.dr_mask_path}")
print("Generating mask automatically...")

try:
# 尝试在当前指定的目录下生成 (如果是 site-packages 可能会失败)
generate_dr_mask_core(
n_symbol_layers=self.n_symbol_layers,
n_inputs=self.n_inputs,
ops=self.operators, # 直接传 list
ops=self.operators,
save_dir=self.dr_mask_dir,
device="cuda" if torch.cuda.is_available() else "cpu"
)
except (PermissionError, OSError):
# 3.3 权限不足的回退策略
print(f"[Warning] No permission to write to {self.dr_mask_dir} (likely installed in system path).")
print("[Info] Fallback: Generating mask in current working directory './dr_mask'")

# 更改目录到当前工作目录
self.dr_mask_dir = os.path.abspath("./dr_mask")
self.dr_mask_path = os.path.join(self.dr_mask_dir, file_name_mask)

Expand All @@ -597,7 +587,6 @@ def __init__(

print("Generation finished.")

# 3.4 加载 Mask
print(f"Loading drmask from {self.dr_mask_path}")
try:
dr_mask_np = np.load(self.dr_mask_path)
Expand Down Expand Up @@ -729,8 +718,6 @@ def fit(
)
)
# self.drm
print("len(self.triu_ls):")
print(len(self.triu_ls))
print("=" * 40)

print("num of samples:", len(X))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from ..config import Config
from ..expr_utils.utils import time_limit, FinishException
from numpy import inf, seterr

from symengine import sympify as se_sympify

Expand Down
1 change: 0 additions & 1 deletion psrn/model/token_generator/GP/model/ga/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from numba import jit
import numpy as np
from functools import lru_cache
from deap import gp

import array
Expand Down
6 changes: 4 additions & 2 deletions psrn/model/token_generator/GP/model/ga/ga.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def __init__(self, config_s: Config):
Initializing the tokens of genetic algorithm
"""
for num, exp in self.exp_dict.items():
if not isinstance(exp, Expression): continue
if not isinstance(exp, Expression):
continue
if exp.child == 0:
pset.renameArguments(**{f'ARG{var_count}': f"exp{num}"})
var_count += 1
Expand Down Expand Up @@ -70,7 +71,8 @@ def ga_play(self, pop_init: List[List[int]]) -> List[List[int]]:
hof = tools.HallOfFame(20)
pops = pop_init
pop = self.config_s.gp.pops
if len(pops) >= pop // 2: pops = random.sample(pops, pop // 2)
if len(pops) >= pop // 2:
pops = random.sample(pops, pop // 2)
pops = [creator.Individual(tokens_to_deap(p, self.pset)) for p in pops]
pops += self.toolbox.population(n=pop - len(pops))
_ = algorithms.eaSimple(pops, self.toolbox, self.config_s.gp.cxpb, self.config_s.gp.mutpb,
Expand Down
2 changes: 0 additions & 2 deletions psrn/model/token_generator/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from .GP.model.config import Config
from .GP.model.pipeline import Pipeline
import sympy as sp
import random
import itertools
from collections import Counter
from ...utils.exprutils import has_nested_func
MAX_LEN_SET = 1000
Expand Down
36 changes: 0 additions & 36 deletions psrn/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,42 +87,6 @@ def generate_X(ranges, down_sample, distrib="U"):
return points


def get_dynamic_data(dataset_name, file_name):
    """
    Load a header-less dataset CSV and return it with its column roles.

    The file is read from ``./data/<dataset_name>/<file_name>.csv``
    (relative to the current working directory) and its columns are
    labelled with the names registered for ``dataset_name``.

    Parameters
    ----------
    dataset_name : str
        One of ``"custom"``, ``"emps"`` or ``"roughpipe"``.
    file_name : str
        Name of the CSV file without the ``.csv`` extension.

    Returns
    -------
    df : pandas.DataFrame
        The loaded data, with columns renamed to the dataset's names.
    variables_name : list of str
        Every column name except the target, in column order.
    target_name : str
        Name of the target column.

    Raises
    ------
    ValueError
        If ``dataset_name`` is not a known dataset. This is checked before
        the file is read, so an unknown name is reported as such rather
        than as a missing file.

    Example
    =======

    >>> df, variables_name, target_name = get_dynamic_data('custom', 'train')  # doctest: +SKIP
    >>> variables_name  # doctest: +SKIP
    ['x']
    >>> target_name  # doctest: +SKIP
    'y'
    """
    # NOTE: If you use your own dataset, the column names cannot be `C` or `B`,
    # because they are used as constant symbols in the regressor.
    # None of the variables may be capitalized either, because names are
    # lower-cased in eval.
    # Built per call so the returned/assigned lists are never shared.
    schemas = {
        "custom": (["x", "y"], "y"),
        "emps": (["q", "qdot", "qddot", "tau"], "qddot"),
        "roughpipe": (["l", "y", "k"], "y"),
    }
    if dataset_name not in schemas:
        raise ValueError("dataset_name error")
    names, target_name = schemas[dataset_name]

    df = pd.read_csv("./data/" + dataset_name + "/" + file_name + ".csv", header=None)
    df.columns = names

    # Everything that is not the target is an input variable.
    variables_name = [name for name in names if name != target_name]

    return df, variables_name, target_name


def expr_to_Y_pred(expr_sympy, X, variables):
functions = {
"sin": np.sin,
Expand Down
8 changes: 2 additions & 6 deletions psrn/utils/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import numpy as np

import sympy
import math
from .data import get_dynamic_data, expr_to_Y_pred


from .exprutils import time_limit, TimeoutException

from .exprutils import time_limit

def get_sympy_complexity(expr_str):
complexity_dict = {
Expand Down Expand Up @@ -37,7 +33,7 @@ def get_sympy_complexity(expr_str):
complexity = eval(ops_visual_str, complexity_dict)
return complexity
except Exception as e:

print('ERR in get_sympy_complexity:', e)
return 1e99


Expand Down
2 changes: 1 addition & 1 deletion psrn/utils/exprutils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from contextlib import contextmanager
import threading
import _thread
from sympy import Symbol, sin, cos, exp, log, count_ops
from sympy import sin, cos, exp, log
import sympy as sp


Expand Down
6 changes: 1 addition & 5 deletions psrn/utils/gen_dr_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import sympy
from tqdm import tqdm
import click
import ast

try:
from ..model.models import PSRN
Expand All @@ -28,10 +27,7 @@ def generate_dr_mask_core(n_symbol_layers, n_inputs, ops, save_dir="./dr_mask",
elif ops == "koza_sign":
ops = ["Add", "Mul", "Sub", "Div", "Identity", "Sign"]
else:
try:
ops = ast.literal_eval(ops)
except:
ops = eval(ops)
ops = eval(ops)

if not isinstance(ops, list):
raise ValueError(f"Ops must be a list, got {type(ops)}: {ops}")
Expand Down
1 change: 0 additions & 1 deletion test/run_custom_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import click
import time
import numpy as np
import sympy as sp
import pandas as pd

Expand Down