From b52051f7e2ce5c5079fb6970ca41e45f3cc21c86 Mon Sep 17 00:00:00 2001 From: Yujing Huang Date: Thu, 18 Dec 2025 08:44:58 -0500 Subject: [PATCH] create instance variable 'templateFileName' for class Samseg and SamsegLongitudinal - set self.templateFileName = os.path.join(self.atlasDir, 'template.nii.gz') if it doesn't exist, set self.templateFileName = os.path.join(self.atlasDir, 'template.nii') - replace 'os.path.join(self.atlasDir, 'template.nii')' referencing with self.templateFileName --- samseg/Samseg.py | 10 +++++++--- samseg/SamsegLongitudinal.py | 14 ++++++++------ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/samseg/Samseg.py b/samseg/Samseg.py index f0bc947..d95ab0a 100644 --- a/samseg/Samseg.py +++ b/samseg/Samseg.py @@ -67,6 +67,10 @@ def __init__(self, raise ValueError('number of mode names does not match number of input images') self.modeNames = modeNames + self.templateFileName = os.path.join(self.atlasDir, 'template.nii.gz') + if (not os.path.isfile(self.templateFileName)): + self.templateFileName = os.path.join(self.atlasDir, 'template.nii') + # Eugenio: there's a bug in ITK that will cause kvlImage to fail if it contqins the string "recon" ... # If this problem is not exclusive to the photo mode (RGB), we should move this chunk of code outside the if # While at it, we also create a grayscale version and a version with a bit of noise around the cerebrum (so the @@ -98,7 +102,7 @@ def __init__(self, # Initialize some objects self.affine = Affine( imageFileName=self.imageFileNames[0], meshCollectionFileName=os.path.join(self.atlasDir, 'atlasForAffineRegistration.txt.gz'), - templateFileName=os.path.join(self.atlasDir, 'template.nii.gz' ) ) + templateFileName=self.templateFileName ) self.probabilisticAtlas = ProbabilisticAtlas() # Get full model specifications and optimization options (using default unless overridden by user) @@ -323,7 +327,7 @@ def preProcess(self): else: self.imageBuffers, self.transform, self.voxelSpacing, self.cropping = readCroppedImages( self.imageFileNames, - os.path.join(self.atlasDir, 'template.nii.gz'), + self.templateFileName, self.imageToImageTransformMatrix ) @@ -597,7 +601,7 @@ def saveWarpField(self, filename): # extract geometries source = sf.load_volume(self.imageFileNames[0]).geom - target = sf.load_volume(os.path.join(self.atlasDir, 'template.nii.gz')).geom + target = sf.load_volume(self.templateFileName).geom # extract vox-to-vox template transform # TODO: Grabbing the transform from the saved .mat file in either the cross or base diff --git a/samseg/SamsegLongitudinal.py b/samseg/SamsegLongitudinal.py index c745073..f4ed2f1 100644 --- a/samseg/SamsegLongitudinal.py +++ b/samseg/SamsegLongitudinal.py @@ -128,6 +128,10 @@ def __init__(self, for tp in range(self.numberOfTimepoints): self.tpToBaseTransforms.append(np.eye(4)) + self.templateFileName = os.path.join(self.atlasDir, 'template.nii.gz') + if (not os.path.isfile(self.templateFileName)): + self.templateFileName = os.path.join(self.atlasDir, 'template.nii') + # Set image-to-image matrix if provided self.imageToImageTransformMatrix = imageToImageTransformMatrix @@ -209,12 +213,11 @@ def constructAndRegisterSubjectSpecificTemplate(self, initTransformFile=None): if self.imageToImageTransformMatrix is None: # Affine atlas registration to sst - templateFileName = os.path.join(self.atlasDir, 'template.nii.gz') affineRegistrationMeshCollectionFileName = os.path.join(self.atlasDir, 'atlasForAffineRegistration.txt.gz') affine = Affine(imageFileName=self.sstFileNames[0], meshCollectionFileName=affineRegistrationMeshCollectionFileName, - templateFileName=templateFileName) + templateFileName=self.templateFileName) self.imageToImageTransformMatrix, _ = affine.registerAtlas(savePath=sstDir, visualizer=self.visualizer, initTransform=initTransform) @@ -229,8 +232,7 @@ def preProcess(self): # # ======================================================================================= - templateFileName = os.path.join(self.atlasDir, 'template.nii.gz') - self.sstModel.imageBuffers, self.sstModel.transform, self.sstModel.voxelSpacing, self.sstModel.cropping = readCroppedImages(self.sstFileNames, templateFileName, self.imageToImageTransformMatrix) + self.sstModel.imageBuffers, self.sstModel.transform, self.sstModel.voxelSpacing, self.sstModel.cropping = readCroppedImages(self.sstFileNames, self.templateFileName, self.imageToImageTransformMatrix) self.imageBuffersList = [] self.voxelSpacings = [] @@ -241,7 +243,7 @@ def preProcess(self): self.imageBuffersList = [] for imageFileNames in self.imageFileNamesList: - imageBuffers, _, _, _ = readCroppedImages(imageFileNames, templateFileName, + imageBuffers, _, _, _ = readCroppedImages(imageFileNames, self.templateFileName, self.imageToImageTransformMatrix) self.imageBuffersList.append(imageBuffers) @@ -290,7 +292,7 @@ def preProcess(self): tmpS = sf.load_volume(os.path.join(self.savePath, "base", "template_coregistered.mgz")) pToTpTransform = tmpTp.geom.world2vox @ self.tpToBaseTransforms[timepointNumber].inv() @ tmpS.geom.vox2world @ self.imageToImageTransformMatrix - imageBuffers, transform, voxelSpacing, cropping = readCroppedImages(imageFileNames, templateFileName, pToTpTransform.matrix) + imageBuffers, transform, voxelSpacing, cropping = readCroppedImages(imageFileNames, self.templateFileName, pToTpTransform.matrix) # self.imageBuffersList.append(imageBuffers)