From 7d02dbbb74372949e8d86a4b646e530478377803 Mon Sep 17 00:00:00 2001 From: Filya Geikyan Date: Mon, 23 Sep 2024 15:01:40 +0400 Subject: [PATCH 1/4] safe conversion --- torchtitan/utils/safe.py | 465 ++++++++++++++++++++++++++ torchtitan/utils/text_format_utils.py | 4 + 2 files changed, 469 insertions(+) create mode 100644 torchtitan/utils/safe.py diff --git a/torchtitan/utils/safe.py b/torchtitan/utils/safe.py new file mode 100644 index 00000000..6f51df09 --- /dev/null +++ b/torchtitan/utils/safe.py @@ -0,0 +1,465 @@ +import itertools +import re +from collections import Counter +from contextlib import suppress +from typing import Callable, List, Optional, Union + +import datamol as dm +import numpy as np +from rdkit import Chem +from rdkit.Chem import BRICS + +class SAFEDecodeError(Exception): + """Raised when a string cannot be decoded with the given encoding.""" + pass + +class SAFEEncodeError(Exception): + """Raised when a molecule cannot be encoded using SAFE.""" + pass + + +class SAFEFragmentationError(Exception): + """Raised when a the slicing algorithm return empty bonds.""" + pass + + +class SAFEConverter: + """Molecule line notation conversion from SMILES to SAFE + + A SAFE representation is a string based representation of a molecule decomposition into fragment components, + separated by a dot ('.'). Note that each component (fragment) might not be a valid molecule by themselves, + unless explicitely correct to add missing hydrogens. + + !!! 
note "Slicing algorithms" + + By default SAFE strings are generated using `BRICS`, however, the following alternative are supported: + + * [Hussain-Rea (`hr`)](https://pubs.acs.org/doi/10.1021/ci900450m) + * [RECAP (`recap`)](https://pubmed.ncbi.nlm.nih.gov/9611787/) + * [RDKit's MMPA (`mmpa`)](https://www.rdkit.org/docs/source/rdkit.Chem.rdMMPA.html) + * Any possible attachment points (`attach`) + + Furthermore, you can also provide your own slicing algorithm, which should return a pair of atoms + corresponding to the bonds to break. + + """ + + SUPPORTED_SLICERS = ["hr", "rotatable", "recap", "mmpa", "attach", "brics"] + __SLICE_SMARTS = { + "hr": ["[*]!@-[*]"], # any non ring single bond + "recap": [ + "[$([C;!$(C([#7])[#7])](=!@[O]))]!@[$([#7;+0;!D1])]", + "[$(C=!@O)]!@[$([O;+0])]", + "[$([N;!D1;+0;!$(N-C=[#7,#8,#15,#16])](-!@[*]))]-!@[$([*])]", + "[$(C(=!@O)([#7;+0;D2,D3])!@[#7;+0;D2,D3])]!@[$([#7;+0;D2,D3])]", + "[$([O;+0](-!@[#6!$(C=O)])-!@[#6!$(C=O)])]-!@[$([#6!$(C=O)])]", + "C=!@C", + "[N;+1;D4]!@[#6]", + "[$([n;+0])]-!@C", + "[$([O]=[C]-@[N;+0])]-!@[$([C])]", + "c-!@c", + "[$([#7;+0;D2,D3])]-!@[$([S](=[O])=[O])]", + ], + "mmpa": ["[#6+0;!$(*=,#[!#6])]!@!=!#[*]"], # classical mmpa slicing smarts + "attach": ["[*]!@[*]"], # any potential attachment point, including hydrogens when explicit + "rotatable": ["[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]"], + } + + def __init__( + self, + slicer: Optional[Union[str, List[str], Callable]] = "brics", + require_hs: Optional[bool] = None, + use_original_opener_for_attach: bool = True, + ignore_stereo: bool = False, + ): + """Constructor for the SAFE converter + + Args: + slicer: slicer algorithm to use for encoding. + Can either be one of the supported slicing algorithm (SUPPORTED_SLICERS) + or a custom callable that returns the bond ids that can be sliced. + require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added. + `attach` slicer requires adding hydrogens. 
+ use_original_opener_for_attach: whether to use the original branch opener digit when adding back + mapping number to attachment points, or use simple enumeration. + ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined. + + """ + self.slicer = slicer + if isinstance(slicer, str) and slicer.lower() in self.SUPPORTED_SLICERS: + self.slicer = self.__SLICE_SMARTS.get(slicer.lower(), slicer) + if self.slicer != "brics" and isinstance(self.slicer, str): + self.slicer = [self.slicer] + if isinstance(self.slicer, (list, tuple)): + self.slicer = [dm.from_smarts(x) for x in self.slicer] + if any(x is None for x in self.slicer): + raise ValueError(f"Slicer: {slicer} cannot be valid") + self.require_hs = require_hs or (slicer == "attach") + self.use_original_opener_for_attach = use_original_opener_for_attach + self.ignore_stereo = ignore_stereo + + @staticmethod + def randomize(mol: dm.Mol, rng: Optional[int] = None): + """Randomize the position of the atoms in a mol. + + Args: + mol: molecules to randomize + rng: optional seed to use + """ + if isinstance(rng, int): + rng = np.random.default_rng(rng) + if mol.GetNumAtoms() == 0: + return mol + atom_indices = list(range(mol.GetNumAtoms())) + atom_indices = rng.permutation(atom_indices).tolist() + return Chem.RenumberAtoms(mol, atom_indices) + + @classmethod + def _find_branch_number(cls, inp: str): + """Find the branch number and ring closure in the SMILES representation using regexp + + Args: + inp: input smiles + """ + inp = re.sub(r"\[.*?\]", "", inp) # noqa + matching_groups = re.findall(r"((?<=%)\d{2})|((? 
0: + mol = Chem.FragmentOnBonds( + mol, + bonds, + dummyLabels=[(i + bond_map_id, i + bond_map_id) for i in range(len(bonds))], + ) + # here we need to be clever and disable rooted atom as the atom with mapping + + frags = list(Chem.GetMolFrags(mol, asMols=True)) + if randomize: + frags = rng.permutation(frags).tolist() + elif canonical: + frags = sorted( + frags, + key=lambda x: x.GetNumAtoms(), + reverse=True, + ) + + frags_str = [] + for frag in frags: + non_map_atom_idxs = [ + atom.GetIdx() for atom in frag.GetAtoms() if atom.GetAtomicNum() != 0 + ] + frags_str.append( + Chem.MolToSmiles( + frag, + isomericSmiles=True, + canonical=True, # needs to always be true + rootedAtAtom=non_map_atom_idxs[0], + ) + ) + + scaffold_str = ".".join(frags_str) + # EN: fix for https://github.com/datamol-io/safe/issues/37 + # we were using the wrong branch number count which did not take into account + # possible change in digit utilization after bond slicing + scf_branch_num = self._find_branch_number(scaffold_str) + branch_numbers + + # don't capture atom mapping in the scaffold + attach_pos = set(re.findall(r"(\[\d+\*\]|!\[[^:]*:\d+\])", scaffold_str)) + if canonical: + attach_pos = sorted(attach_pos) + starting_num = 1 if len(scf_branch_num) == 0 else max(scf_branch_num) + 1 + for attach in attach_pos: + val = str(starting_num) if starting_num < 10 else f"%{starting_num}" + # we cannot have anything of the form "\([@=-#-$/\]*\d+\)" + attach_regexp = re.compile(r"(" + re.escape(attach) + r")") + scaffold_str = attach_regexp.sub(val, scaffold_str) + starting_num += 1 + + # now we need to remove all the parenthesis around digit only number + wrong_attach = re.compile(r"\(([\%\d]*)\)") + scaffold_str = wrong_attach.sub(r"\g<1>", scaffold_str) + # furthermore, we autoapply rdkit-compatible digit standardization. 
+ if rdkit_safe: + pattern = r"\(([=-@#\/\\]{0,2})(%?\d{1,2})\)" + replacement = r"\g<1>\g<2>" + scaffold_str = re.sub(pattern, replacement, scaffold_str) + if not self.ignore_stereo and has_stereo_bonds and not dm.same_mol(scaffold_str, inp): + print( + "Warning: Ignoring stereo is disabled, but molecule has stereochemistry interferring with SAFE representation" + ) + return scaffold_str + + +def encode( + inp: Union[str, dm.Mol], + canonical: bool = True, + randomize: Optional[bool] = False, + seed: Optional[int] = None, + slicer: Optional[Union[List[str], str, Callable]] = None, + require_hs: Optional[bool] = None, + constraints: Optional[List[dm.Mol]] = None, + ignore_stereo: Optional[bool] = False, +): + """ + Convert input smiles to SAFE representation + + Args: + inp: input smiles + canonical: whether to return canonical SAFE string. Defaults to True + randomize: whether to randomize the safe string encoding. Will be ignored if canonical is provided + seed: optional seed to use when allowing randomization of the SAFE encoding. + slicer: slicer algorithm to use for encoding. Defaults to "brics". + require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added. + constraints: List of molecules or pattern to preserve during the SAFE construction. + ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined. 
+ """ + if slicer is None: + slicer = "brics" + with dm.without_rdkit_log(): + safe_obj = SAFEConverter(slicer=slicer, require_hs=require_hs, ignore_stereo=ignore_stereo) + try: + encoded = safe_obj.encoder( + inp, + canonical=canonical, + randomize=randomize, + constraints=constraints, + seed=seed, + ) + except SAFEFragmentationError as e: + raise e + except Exception as e: + raise SAFEEncodeError(f"Failed to encode {inp} with {slicer}") from e + return encoded + + +def decode( + safe_str: str, + as_mol: bool = False, + canonical: bool = False, + fix: bool = True, + remove_added_hs: bool = True, + remove_dummies: bool = True, + ignore_errors: bool = False, +): + """Convert input SAFE representation to smiles + Args: + safe_str: input SAFE representation to decode as a valid molecule or smiles + as_mol: whether to return a molecule object or a smiles string + canonical: whether to return a canonical smiles or a randomized smiles + fix: whether to fix the SAFE representation to take into account non-connected attachment points + remove_added_hs: whether to remove the hydrogen atoms that have been added to fix the string. 
+ remove_dummies: whether to remove dummy atoms from the SAFE representation + ignore_errors: whether to ignore error and return None on decoding failure or raise an error + + """ + with dm.without_rdkit_log(): + safe_obj = SAFEConverter() + try: + decoded = safe_obj.decoder( + safe_str, + as_mol=as_mol, + canonical=canonical, + fix=fix, + remove_dummies=remove_dummies, + remove_added_hs=remove_added_hs, + ) + + except Exception as e: + if ignore_errors: + return None + raise SAFEDecodeError(f"Failed to decode {safe_str}") from e + return decoded + +def main(): + smiles = "O=C(C#CCN1CCCCC1)Nc1ccc2ncnc(Nc3cccc(Br)c3)c2c1" + safe_string = encode(smiles) + print("SAFE representation:", safe_string) + print("SMILES representation:", decode(safe_string)) + +if __name__ == "main": + main() \ No newline at end of file diff --git a/torchtitan/utils/text_format_utils.py b/torchtitan/utils/text_format_utils.py index 65a0c15e..1953d1bf 100644 --- a/torchtitan/utils/text_format_utils.py +++ b/torchtitan/utils/text_format_utils.py @@ -1,5 +1,6 @@ # Adapted from https://github.com/YerevaNN/ChemLactica/blob/main/chemlactica/utils/text_format_utils.py # All rights reserved +from torchtitan.utils.safe import encode SPECIAL_TAGS = { "SMILES": {"start": "[START_SMILES]", "end": "[END_SMILES]"}, @@ -85,10 +86,13 @@ def generate_formatted_string(compound_json, rng): key_value_pairs = [] key = "SMILES" value = compound_json.get(key, "") + value = encode(value) + print(value) if rng.integers(2) == 0: if value: key_value_pairs.append(format_key_value(key, value, rng)) del compound_json[key] + keys = list(compound_json.keys()) rng.shuffle(keys) From 2ca6810ad48d5cee9e2b08da858d9bb74195b055 Mon Sep 17 00:00:00 2001 From: Filya Geikyan Date: Mon, 23 Sep 2024 16:05:00 +0400 Subject: [PATCH 2/4] add configs --- torchtitan/config_manager.py | 5 +++++ torchtitan/datasets/hf_datasets.py | 7 +++++-- torchtitan/utils/dataset_utils.py | 4 ++-- torchtitan/utils/text_format_utils.py | 8 +++++--- 
train.py | 2 ++ train_configs/debug_model.toml | 1 + 6 files changed, 20 insertions(+), 7 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 0cb260cc..c75899af 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -254,6 +254,11 @@ def __init__(self): default=True, action="store_true", help="Whether to apply loss parallel when sequence parallel is enabled", + ) + self.parser.add_argument( + "--training.representation_type", + default="SMILES", + help="The representation type of the molecule for training the model.", ) self.parser.add_argument( "--experimental.enable_async_tensor_parallel", diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index 06c9d3ea..77f0380d 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -83,6 +83,7 @@ def __init__( dataset_path: Optional[str], data_processing_style: str, tokenizer: Tokenizer, + representation_type: str = "SMILES", seq_len: int = 2048, world_size: int = 1, rank: int = 0, @@ -124,6 +125,7 @@ def __init__( self._tokenizer = tokenizer self.seq_len = seq_len self.infinite = infinite + self.representation_type = representation_type # variables for checkpointing self._sample_idx = 0 @@ -137,7 +139,7 @@ def __iter__(self): while True: for sample_json in self._get_data_iter(): - sample_text = self.data_processing_fn(sample_json, self.rng) + sample_text = self.data_processing_fn(sample_json, self.rng, self.representation_type) sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True) self._all_tokens.extend(sample_tokens) self._sample_idx += 1 @@ -219,10 +221,11 @@ def build_hf_data_loader( seq_len: int, world_size, rank, + representation_type, infinite: bool = True, ): hf_ds = HuggingFaceDataset( - dataset_name, dataset_path, data_processing_style, tokenizer, seq_len, world_size, rank, infinite + dataset_name, dataset_path, data_processing_style, tokenizer, 
representation_type, seq_len, world_size, rank, infinite ) return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size) diff --git a/torchtitan/utils/dataset_utils.py b/torchtitan/utils/dataset_utils.py index 4f5e171e..469a01be 100644 --- a/torchtitan/utils/dataset_utils.py +++ b/torchtitan/utils/dataset_utils.py @@ -17,11 +17,11 @@ def load_jsonl_line(jsonl_line): raise ValueError(f"Error decoding JSON: {e}") -def chemlactica_style_data_processing(sample_json, rng): +def chemlactica_style_data_processing(sample_json, rng, representation_type): try: compound = delete_empty_tags(sample_json) sample_json = generate_formatted_string( - compound, rng + compound, rng, representation_type ) except Exception as e: print(e) diff --git a/torchtitan/utils/text_format_utils.py b/torchtitan/utils/text_format_utils.py index 1953d1bf..9514ecf8 100644 --- a/torchtitan/utils/text_format_utils.py +++ b/torchtitan/utils/text_format_utils.py @@ -82,12 +82,14 @@ def delete_empty_tags(compound_json): return compound_json -def generate_formatted_string(compound_json, rng): +def generate_formatted_string(compound_json, rng, representation_type = "SMILES"): key_value_pairs = [] key = "SMILES" value = compound_json.get(key, "") - value = encode(value) - print(value) + + if representation_type == "SAFE": + value = encode(value) + if rng.integers(2) == 0: if value: key_value_pairs.append(format_key_value(key, value, rng)) diff --git a/train.py b/train.py index 2842cb43..df8d7e7b 100644 --- a/train.py +++ b/train.py @@ -91,6 +91,7 @@ def main(job_config: JobConfig): tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path) # build dataloader + representation_type = job_config.training.representation_type data_loader = build_hf_data_loader( job_config.training.dataset, job_config.training.dataset_path, @@ -100,6 +101,7 @@ def main(job_config: JobConfig): job_config.training.seq_len, dp_degree, dp_rank, + representation_type ) # build model (using meta init) diff --git 
a/train_configs/debug_model.toml b/train_configs/debug_model.toml index fdd360d0..b80718e2 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -44,6 +44,7 @@ tensor_parallel_degree = 1 compile = false dataset = "chemlactica_train_mini" # supported datasets: c4_test (2K), c4 (177M), chemlactica_train_mini (4K) data_process_style="chemlactica_style" +representation_type="SAFE" [experimental] pipeline_parallel_degree = 1 From e608677e4af435525da1c02ce6323d1ec3e0ab94 Mon Sep 17 00:00:00 2001 From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:38:55 +0400 Subject: [PATCH 3/4] Revert "Safe" --- torchtitan/config_manager.py | 5 - torchtitan/datasets/hf_datasets.py | 4 +- torchtitan/utils/dataset_utils.py | 4 +- torchtitan/utils/safe.py | 465 -------------------------- torchtitan/utils/text_format_utils.py | 8 +- train.py | 2 - train_configs/debug_model.toml | 1 - 7 files changed, 4 insertions(+), 485 deletions(-) delete mode 100644 torchtitan/utils/safe.py diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 4aa564b6..14ef3a4e 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -270,11 +270,6 @@ def __init__(self): default=True, action="store_true", help="Whether to apply loss parallel when sequence parallel is enabled", - ) - self.parser.add_argument( - "--training.representation_type", - default="SMILES", - help="The representation type of the molecule for training the model.", ) self.parser.add_argument( "--experimental.enable_async_tensor_parallel", diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index 1509695f..6840f469 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -87,7 +87,6 @@ def __init__( dataset_path: Optional[str], data_processing_style: str, tokenizer: Tokenizer, - representation_type: str = "SMILES", seq_len: int = 2048, world_size: int = 
1, rank: int = 0, @@ -134,7 +133,6 @@ def __init__( self._tokenizer = tokenizer self.seq_len = seq_len self.infinite = infinite - self.representation_type = representation_type self.rank = rank self.world_size = world_size @@ -144,6 +142,7 @@ def __init__( else: self.store = None + # variables for checkpointing self._sample_idx = 0 self._all_tokens: List[int] = [] @@ -256,7 +255,6 @@ def build_hf_data_loader( seq_len: int, world_size, rank, - representation_type, infinite: bool = True, pin_memory: bool = False, num_workers: int = 2, diff --git a/torchtitan/utils/dataset_utils.py b/torchtitan/utils/dataset_utils.py index c397ee4c..40ef7aae 100644 --- a/torchtitan/utils/dataset_utils.py +++ b/torchtitan/utils/dataset_utils.py @@ -30,12 +30,12 @@ def load_jsonl_line(jsonl_line): raise ValueError(f"Error decoding JSON: {e}") -def chemlactica_style_data_processing(sample_json, rng, representation_type): +def chemlactica_style_data_processing(sample_json, rng): try: sample_json = json.loads(sample_json["text"]) compound = delete_empty_tags(sample_json) sample_json = generate_formatted_string( - compound, rng, representation_type + compound, rng ) except Exception as e: print(e) diff --git a/torchtitan/utils/safe.py b/torchtitan/utils/safe.py deleted file mode 100644 index 6f51df09..00000000 --- a/torchtitan/utils/safe.py +++ /dev/null @@ -1,465 +0,0 @@ -import itertools -import re -from collections import Counter -from contextlib import suppress -from typing import Callable, List, Optional, Union - -import datamol as dm -import numpy as np -from rdkit import Chem -from rdkit.Chem import BRICS - -class SAFEDecodeError(Exception): - """Raised when a string cannot be decoded with the given encoding.""" - pass - -class SAFEEncodeError(Exception): - """Raised when a molecule cannot be encoded using SAFE.""" - pass - - -class SAFEFragmentationError(Exception): - """Raised when a the slicing algorithm return empty bonds.""" - pass - - -class SAFEConverter: - """Molecule line 
notation conversion from SMILES to SAFE - - A SAFE representation is a string based representation of a molecule decomposition into fragment components, - separated by a dot ('.'). Note that each component (fragment) might not be a valid molecule by themselves, - unless explicitely correct to add missing hydrogens. - - !!! note "Slicing algorithms" - - By default SAFE strings are generated using `BRICS`, however, the following alternative are supported: - - * [Hussain-Rea (`hr`)](https://pubs.acs.org/doi/10.1021/ci900450m) - * [RECAP (`recap`)](https://pubmed.ncbi.nlm.nih.gov/9611787/) - * [RDKit's MMPA (`mmpa`)](https://www.rdkit.org/docs/source/rdkit.Chem.rdMMPA.html) - * Any possible attachment points (`attach`) - - Furthermore, you can also provide your own slicing algorithm, which should return a pair of atoms - corresponding to the bonds to break. - - """ - - SUPPORTED_SLICERS = ["hr", "rotatable", "recap", "mmpa", "attach", "brics"] - __SLICE_SMARTS = { - "hr": ["[*]!@-[*]"], # any non ring single bond - "recap": [ - "[$([C;!$(C([#7])[#7])](=!@[O]))]!@[$([#7;+0;!D1])]", - "[$(C=!@O)]!@[$([O;+0])]", - "[$([N;!D1;+0;!$(N-C=[#7,#8,#15,#16])](-!@[*]))]-!@[$([*])]", - "[$(C(=!@O)([#7;+0;D2,D3])!@[#7;+0;D2,D3])]!@[$([#7;+0;D2,D3])]", - "[$([O;+0](-!@[#6!$(C=O)])-!@[#6!$(C=O)])]-!@[$([#6!$(C=O)])]", - "C=!@C", - "[N;+1;D4]!@[#6]", - "[$([n;+0])]-!@C", - "[$([O]=[C]-@[N;+0])]-!@[$([C])]", - "c-!@c", - "[$([#7;+0;D2,D3])]-!@[$([S](=[O])=[O])]", - ], - "mmpa": ["[#6+0;!$(*=,#[!#6])]!@!=!#[*]"], # classical mmpa slicing smarts - "attach": ["[*]!@[*]"], # any potential attachment point, including hydrogens when explicit - "rotatable": ["[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]"], - } - - def __init__( - self, - slicer: Optional[Union[str, List[str], Callable]] = "brics", - require_hs: Optional[bool] = None, - use_original_opener_for_attach: bool = True, - ignore_stereo: bool = False, - ): - """Constructor for the SAFE converter - - Args: - slicer: slicer algorithm to use for 
encoding. - Can either be one of the supported slicing algorithm (SUPPORTED_SLICERS) - or a custom callable that returns the bond ids that can be sliced. - require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added. - `attach` slicer requires adding hydrogens. - use_original_opener_for_attach: whether to use the original branch opener digit when adding back - mapping number to attachment points, or use simple enumeration. - ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined. - - """ - self.slicer = slicer - if isinstance(slicer, str) and slicer.lower() in self.SUPPORTED_SLICERS: - self.slicer = self.__SLICE_SMARTS.get(slicer.lower(), slicer) - if self.slicer != "brics" and isinstance(self.slicer, str): - self.slicer = [self.slicer] - if isinstance(self.slicer, (list, tuple)): - self.slicer = [dm.from_smarts(x) for x in self.slicer] - if any(x is None for x in self.slicer): - raise ValueError(f"Slicer: {slicer} cannot be valid") - self.require_hs = require_hs or (slicer == "attach") - self.use_original_opener_for_attach = use_original_opener_for_attach - self.ignore_stereo = ignore_stereo - - @staticmethod - def randomize(mol: dm.Mol, rng: Optional[int] = None): - """Randomize the position of the atoms in a mol. - - Args: - mol: molecules to randomize - rng: optional seed to use - """ - if isinstance(rng, int): - rng = np.random.default_rng(rng) - if mol.GetNumAtoms() == 0: - return mol - atom_indices = list(range(mol.GetNumAtoms())) - atom_indices = rng.permutation(atom_indices).tolist() - return Chem.RenumberAtoms(mol, atom_indices) - - @classmethod - def _find_branch_number(cls, inp: str): - """Find the branch number and ring closure in the SMILES representation using regexp - - Args: - inp: input smiles - """ - inp = re.sub(r"\[.*?\]", "", inp) # noqa - matching_groups = re.findall(r"((?<=%)\d{2})|((? 
0: - mol = Chem.FragmentOnBonds( - mol, - bonds, - dummyLabels=[(i + bond_map_id, i + bond_map_id) for i in range(len(bonds))], - ) - # here we need to be clever and disable rooted atom as the atom with mapping - - frags = list(Chem.GetMolFrags(mol, asMols=True)) - if randomize: - frags = rng.permutation(frags).tolist() - elif canonical: - frags = sorted( - frags, - key=lambda x: x.GetNumAtoms(), - reverse=True, - ) - - frags_str = [] - for frag in frags: - non_map_atom_idxs = [ - atom.GetIdx() for atom in frag.GetAtoms() if atom.GetAtomicNum() != 0 - ] - frags_str.append( - Chem.MolToSmiles( - frag, - isomericSmiles=True, - canonical=True, # needs to always be true - rootedAtAtom=non_map_atom_idxs[0], - ) - ) - - scaffold_str = ".".join(frags_str) - # EN: fix for https://github.com/datamol-io/safe/issues/37 - # we were using the wrong branch number count which did not take into account - # possible change in digit utilization after bond slicing - scf_branch_num = self._find_branch_number(scaffold_str) + branch_numbers - - # don't capture atom mapping in the scaffold - attach_pos = set(re.findall(r"(\[\d+\*\]|!\[[^:]*:\d+\])", scaffold_str)) - if canonical: - attach_pos = sorted(attach_pos) - starting_num = 1 if len(scf_branch_num) == 0 else max(scf_branch_num) + 1 - for attach in attach_pos: - val = str(starting_num) if starting_num < 10 else f"%{starting_num}" - # we cannot have anything of the form "\([@=-#-$/\]*\d+\)" - attach_regexp = re.compile(r"(" + re.escape(attach) + r")") - scaffold_str = attach_regexp.sub(val, scaffold_str) - starting_num += 1 - - # now we need to remove all the parenthesis around digit only number - wrong_attach = re.compile(r"\(([\%\d]*)\)") - scaffold_str = wrong_attach.sub(r"\g<1>", scaffold_str) - # furthermore, we autoapply rdkit-compatible digit standardization. 
- if rdkit_safe: - pattern = r"\(([=-@#\/\\]{0,2})(%?\d{1,2})\)" - replacement = r"\g<1>\g<2>" - scaffold_str = re.sub(pattern, replacement, scaffold_str) - if not self.ignore_stereo and has_stereo_bonds and not dm.same_mol(scaffold_str, inp): - print( - "Warning: Ignoring stereo is disabled, but molecule has stereochemistry interferring with SAFE representation" - ) - return scaffold_str - - -def encode( - inp: Union[str, dm.Mol], - canonical: bool = True, - randomize: Optional[bool] = False, - seed: Optional[int] = None, - slicer: Optional[Union[List[str], str, Callable]] = None, - require_hs: Optional[bool] = None, - constraints: Optional[List[dm.Mol]] = None, - ignore_stereo: Optional[bool] = False, -): - """ - Convert input smiles to SAFE representation - - Args: - inp: input smiles - canonical: whether to return canonical SAFE string. Defaults to True - randomize: whether to randomize the safe string encoding. Will be ignored if canonical is provided - seed: optional seed to use when allowing randomization of the SAFE encoding. - slicer: slicer algorithm to use for encoding. Defaults to "brics". - require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added. - constraints: List of molecules or pattern to preserve during the SAFE construction. - ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined. 
- """ - if slicer is None: - slicer = "brics" - with dm.without_rdkit_log(): - safe_obj = SAFEConverter(slicer=slicer, require_hs=require_hs, ignore_stereo=ignore_stereo) - try: - encoded = safe_obj.encoder( - inp, - canonical=canonical, - randomize=randomize, - constraints=constraints, - seed=seed, - ) - except SAFEFragmentationError as e: - raise e - except Exception as e: - raise SAFEEncodeError(f"Failed to encode {inp} with {slicer}") from e - return encoded - - -def decode( - safe_str: str, - as_mol: bool = False, - canonical: bool = False, - fix: bool = True, - remove_added_hs: bool = True, - remove_dummies: bool = True, - ignore_errors: bool = False, -): - """Convert input SAFE representation to smiles - Args: - safe_str: input SAFE representation to decode as a valid molecule or smiles - as_mol: whether to return a molecule object or a smiles string - canonical: whether to return a canonical smiles or a randomized smiles - fix: whether to fix the SAFE representation to take into account non-connected attachment points - remove_added_hs: whether to remove the hydrogen atoms that have been added to fix the string. 
- remove_dummies: whether to remove dummy atoms from the SAFE representation - ignore_errors: whether to ignore error and return None on decoding failure or raise an error - - """ - with dm.without_rdkit_log(): - safe_obj = SAFEConverter() - try: - decoded = safe_obj.decoder( - safe_str, - as_mol=as_mol, - canonical=canonical, - fix=fix, - remove_dummies=remove_dummies, - remove_added_hs=remove_added_hs, - ) - - except Exception as e: - if ignore_errors: - return None - raise SAFEDecodeError(f"Failed to decode {safe_str}") from e - return decoded - -def main(): - smiles = "O=C(C#CCN1CCCCC1)Nc1ccc2ncnc(Nc3cccc(Br)c3)c2c1" - safe_string = encode(smiles) - print("SAFE representation:", safe_string) - print("SMILES representation:", decode(safe_string)) - -if __name__ == "main": - main() \ No newline at end of file diff --git a/torchtitan/utils/text_format_utils.py b/torchtitan/utils/text_format_utils.py index 9514ecf8..65a0c15e 100644 --- a/torchtitan/utils/text_format_utils.py +++ b/torchtitan/utils/text_format_utils.py @@ -1,6 +1,5 @@ # Adapted from https://github.com/YerevaNN/ChemLactica/blob/main/chemlactica/utils/text_format_utils.py # All rights reserved -from torchtitan.utils.safe import encode SPECIAL_TAGS = { "SMILES": {"start": "[START_SMILES]", "end": "[END_SMILES]"}, @@ -82,19 +81,14 @@ def delete_empty_tags(compound_json): return compound_json -def generate_formatted_string(compound_json, rng, representation_type = "SMILES"): +def generate_formatted_string(compound_json, rng): key_value_pairs = [] key = "SMILES" value = compound_json.get(key, "") - - if representation_type == "SAFE": - value = encode(value) - if rng.integers(2) == 0: if value: key_value_pairs.append(format_key_value(key, value, rng)) del compound_json[key] - keys = list(compound_json.keys()) rng.shuffle(keys) diff --git a/train.py b/train.py index 797fca61..b7ee0d23 100644 --- a/train.py +++ b/train.py @@ -93,7 +93,6 @@ def main(job_config: JobConfig): tokenizer = 
build_tokenizer(tokenizer_type, job_config.model.tokenizer_path) # build dataloader - representation_type = job_config.training.representation_type data_loader = build_hf_data_loader( job_config.training.dataset, job_config.training.dataset_path, @@ -103,7 +102,6 @@ def main(job_config: JobConfig): job_config.training.seq_len, dp_degree, dp_rank, - representation_type, pin_memory = job_config.dataloader.pin_memory, num_workers = job_config.dataloader.num_workers, special_mode = job_config.dataloader.special_mode, diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 7aa7caa1..2829d098 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -44,7 +44,6 @@ tensor_parallel_degree = 1 compile = true dataset = "chemlactica_train_mini" # supported datasets: c4_test (2K), c4 (177M), chemlactica_train_mini (4K) data_process_style="chemlactica_style" -representation_type="SAFE" [experimental] pipeline_parallel_degree = 1 From e82a73e2144fa2cec472039b368a2f78d85cd9a8 Mon Sep 17 00:00:00 2001 From: Filya Geikyan Date: Thu, 26 Sep 2024 18:54:34 +0400 Subject: [PATCH 4/4] final safe --- torchtitan/config_manager.py | 5 + torchtitan/datasets/hf_datasets.py | 8 +- torchtitan/utils/dataset_utils.py | 4 +- torchtitan/utils/safe.py | 465 ++++++++++++++++++++++++++ torchtitan/utils/text_format_utils.py | 8 +- train.py | 2 + train_configs/debug_model.toml | 1 + 7 files changed, 487 insertions(+), 6 deletions(-) create mode 100644 torchtitan/utils/safe.py diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 14ef3a4e..4aa564b6 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -270,6 +270,11 @@ def __init__(self): default=True, action="store_true", help="Whether to apply loss parallel when sequence parallel is enabled", + ) + self.parser.add_argument( + "--training.representation_type", + default="SMILES", + help="The representation type of the molecule for training the 
model.", ) self.parser.add_argument( "--experimental.enable_async_tensor_parallel", diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index 6840f469..582503c7 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -87,6 +87,7 @@ def __init__( dataset_path: Optional[str], data_processing_style: str, tokenizer: Tokenizer, + representation_type: str = "SMILES", seq_len: int = 2048, world_size: int = 1, rank: int = 0, @@ -135,6 +136,7 @@ def __init__( self.infinite = infinite self.rank = rank self.world_size = world_size + self.representation_type = representation_type # for non sync communication between ranks if not self.infinite and store: @@ -142,7 +144,6 @@ def __init__( else: self.store = None - # variables for checkpointing self._sample_idx = 0 self._all_tokens: List[int] = [] @@ -172,7 +173,7 @@ def __iter__(self): for sample_json in self._get_data_iter(): if self._some_rank_finished(): break - sample_text = self.data_processing_fn(sample_json, self.rng) + sample_text = self.data_processing_fn(sample_json, self.rng, self.representation_type) sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True) self._all_tokens.extend(sample_tokens) self._sample_idx += 1 @@ -255,6 +256,7 @@ def build_hf_data_loader( seq_len: int, world_size, rank, + representation_type, infinite: bool = True, pin_memory: bool = False, num_workers: int = 2, @@ -268,7 +270,7 @@ def build_hf_data_loader( data_completion_store = None hf_ds = HuggingFaceDataset( - dataset_name, dataset_path, data_processing_style, tokenizer, seq_len, world_size, rank, infinite, special_mode,store = data_completion_store + dataset_name, dataset_path, data_processing_style, tokenizer, representation_type, seq_len, world_size, rank, infinite, special_mode,store = data_completion_store ) return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers) diff --git 
a/torchtitan/utils/dataset_utils.py b/torchtitan/utils/dataset_utils.py index 40ef7aae..c397ee4c 100644 --- a/torchtitan/utils/dataset_utils.py +++ b/torchtitan/utils/dataset_utils.py @@ -30,12 +30,12 @@ def load_jsonl_line(jsonl_line): raise ValueError(f"Error decoding JSON: {e}") -def chemlactica_style_data_processing(sample_json, rng): +def chemlactica_style_data_processing(sample_json, rng, representation_type): try: sample_json = json.loads(sample_json["text"]) compound = delete_empty_tags(sample_json) sample_json = generate_formatted_string( - compound, rng + compound, rng, representation_type ) except Exception as e: print(e) diff --git a/torchtitan/utils/safe.py b/torchtitan/utils/safe.py new file mode 100644 index 00000000..6f51df09 --- /dev/null +++ b/torchtitan/utils/safe.py @@ -0,0 +1,465 @@ +import itertools +import re +from collections import Counter +from contextlib import suppress +from typing import Callable, List, Optional, Union + +import datamol as dm +import numpy as np +from rdkit import Chem +from rdkit.Chem import BRICS + +class SAFEDecodeError(Exception): + """Raised when a string cannot be decoded with the given encoding.""" + pass + +class SAFEEncodeError(Exception): + """Raised when a molecule cannot be encoded using SAFE.""" + pass + + +class SAFEFragmentationError(Exception): + """Raised when the slicing algorithm returns empty bonds.""" + pass + + +class SAFEConverter: + """Molecule line notation conversion from SMILES to SAFE + + A SAFE representation is a string based representation of a molecule decomposition into fragment components, + separated by a dot ('.'). Note that each component (fragment) might not be a valid molecule by itself, + unless explicitly corrected to add missing hydrogens. + + !!! 
note "Slicing algorithms" + + By default SAFE strings are generated using `BRICS`, however, the following alternative are supported: + + * [Hussain-Rea (`hr`)](https://pubs.acs.org/doi/10.1021/ci900450m) + * [RECAP (`recap`)](https://pubmed.ncbi.nlm.nih.gov/9611787/) + * [RDKit's MMPA (`mmpa`)](https://www.rdkit.org/docs/source/rdkit.Chem.rdMMPA.html) + * Any possible attachment points (`attach`) + + Furthermore, you can also provide your own slicing algorithm, which should return a pair of atoms + corresponding to the bonds to break. + + """ + + SUPPORTED_SLICERS = ["hr", "rotatable", "recap", "mmpa", "attach", "brics"] + __SLICE_SMARTS = { + "hr": ["[*]!@-[*]"], # any non ring single bond + "recap": [ + "[$([C;!$(C([#7])[#7])](=!@[O]))]!@[$([#7;+0;!D1])]", + "[$(C=!@O)]!@[$([O;+0])]", + "[$([N;!D1;+0;!$(N-C=[#7,#8,#15,#16])](-!@[*]))]-!@[$([*])]", + "[$(C(=!@O)([#7;+0;D2,D3])!@[#7;+0;D2,D3])]!@[$([#7;+0;D2,D3])]", + "[$([O;+0](-!@[#6!$(C=O)])-!@[#6!$(C=O)])]-!@[$([#6!$(C=O)])]", + "C=!@C", + "[N;+1;D4]!@[#6]", + "[$([n;+0])]-!@C", + "[$([O]=[C]-@[N;+0])]-!@[$([C])]", + "c-!@c", + "[$([#7;+0;D2,D3])]-!@[$([S](=[O])=[O])]", + ], + "mmpa": ["[#6+0;!$(*=,#[!#6])]!@!=!#[*]"], # classical mmpa slicing smarts + "attach": ["[*]!@[*]"], # any potential attachment point, including hydrogens when explicit + "rotatable": ["[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]"], + } + + def __init__( + self, + slicer: Optional[Union[str, List[str], Callable]] = "brics", + require_hs: Optional[bool] = None, + use_original_opener_for_attach: bool = True, + ignore_stereo: bool = False, + ): + """Constructor for the SAFE converter + + Args: + slicer: slicer algorithm to use for encoding. + Can either be one of the supported slicing algorithm (SUPPORTED_SLICERS) + or a custom callable that returns the bond ids that can be sliced. + require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added. + `attach` slicer requires adding hydrogens. 
+ use_original_opener_for_attach: whether to use the original branch opener digit when adding back + mapping number to attachment points, or use simple enumeration. + ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined. + + """ + self.slicer = slicer + if isinstance(slicer, str) and slicer.lower() in self.SUPPORTED_SLICERS: + self.slicer = self.__SLICE_SMARTS.get(slicer.lower(), slicer) + if self.slicer != "brics" and isinstance(self.slicer, str): + self.slicer = [self.slicer] + if isinstance(self.slicer, (list, tuple)): + self.slicer = [dm.from_smarts(x) for x in self.slicer] + if any(x is None for x in self.slicer): + raise ValueError(f"Slicer: {slicer} cannot be valid") + self.require_hs = require_hs or (slicer == "attach") + self.use_original_opener_for_attach = use_original_opener_for_attach + self.ignore_stereo = ignore_stereo + + @staticmethod + def randomize(mol: dm.Mol, rng: Optional[int] = None): + """Randomize the position of the atoms in a mol. + + Args: + mol: molecules to randomize + rng: optional seed to use + """ + if isinstance(rng, int): + rng = np.random.default_rng(rng) + if mol.GetNumAtoms() == 0: + return mol + atom_indices = list(range(mol.GetNumAtoms())) + atom_indices = rng.permutation(atom_indices).tolist() + return Chem.RenumberAtoms(mol, atom_indices) + + @classmethod + def _find_branch_number(cls, inp: str): + """Find the branch number and ring closure in the SMILES representation using regexp + + Args: + inp: input smiles + """ + inp = re.sub(r"\[.*?\]", "", inp) # noqa + matching_groups = re.findall(r"((?<=%)\d{2})|((? 
0: + mol = Chem.FragmentOnBonds( + mol, + bonds, + dummyLabels=[(i + bond_map_id, i + bond_map_id) for i in range(len(bonds))], + ) + # here we need to be clever and disable rooted atom as the atom with mapping + + frags = list(Chem.GetMolFrags(mol, asMols=True)) + if randomize: + frags = rng.permutation(frags).tolist() + elif canonical: + frags = sorted( + frags, + key=lambda x: x.GetNumAtoms(), + reverse=True, + ) + + frags_str = [] + for frag in frags: + non_map_atom_idxs = [ + atom.GetIdx() for atom in frag.GetAtoms() if atom.GetAtomicNum() != 0 + ] + frags_str.append( + Chem.MolToSmiles( + frag, + isomericSmiles=True, + canonical=True, # needs to always be true + rootedAtAtom=non_map_atom_idxs[0], + ) + ) + + scaffold_str = ".".join(frags_str) + # EN: fix for https://github.com/datamol-io/safe/issues/37 + # we were using the wrong branch number count which did not take into account + # possible change in digit utilization after bond slicing + scf_branch_num = self._find_branch_number(scaffold_str) + branch_numbers + + # don't capture atom mapping in the scaffold + attach_pos = set(re.findall(r"(\[\d+\*\]|!\[[^:]*:\d+\])", scaffold_str)) + if canonical: + attach_pos = sorted(attach_pos) + starting_num = 1 if len(scf_branch_num) == 0 else max(scf_branch_num) + 1 + for attach in attach_pos: + val = str(starting_num) if starting_num < 10 else f"%{starting_num}" + # we cannot have anything of the form "\([@=-#-$/\]*\d+\)" + attach_regexp = re.compile(r"(" + re.escape(attach) + r")") + scaffold_str = attach_regexp.sub(val, scaffold_str) + starting_num += 1 + + # now we need to remove all the parenthesis around digit only number + wrong_attach = re.compile(r"\(([\%\d]*)\)") + scaffold_str = wrong_attach.sub(r"\g<1>", scaffold_str) + # furthermore, we autoapply rdkit-compatible digit standardization. 
+ if rdkit_safe: + pattern = r"\(([=-@#\/\\]{0,2})(%?\d{1,2})\)" + replacement = r"\g<1>\g<2>" + scaffold_str = re.sub(pattern, replacement, scaffold_str) + if not self.ignore_stereo and has_stereo_bonds and not dm.same_mol(scaffold_str, inp): + print( + "Warning: Ignoring stereo is disabled, but molecule has stereochemistry interferring with SAFE representation" + ) + return scaffold_str + + +def encode( + inp: Union[str, dm.Mol], + canonical: bool = True, + randomize: Optional[bool] = False, + seed: Optional[int] = None, + slicer: Optional[Union[List[str], str, Callable]] = None, + require_hs: Optional[bool] = None, + constraints: Optional[List[dm.Mol]] = None, + ignore_stereo: Optional[bool] = False, +): + """ + Convert input smiles to SAFE representation + + Args: + inp: input smiles + canonical: whether to return canonical SAFE string. Defaults to True + randomize: whether to randomize the safe string encoding. Will be ignored if canonical is provided + seed: optional seed to use when allowing randomization of the SAFE encoding. + slicer: slicer algorithm to use for encoding. Defaults to "brics". + require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added. + constraints: List of molecules or pattern to preserve during the SAFE construction. + ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined. 
+ """ + if slicer is None: + slicer = "brics" + with dm.without_rdkit_log(): + safe_obj = SAFEConverter(slicer=slicer, require_hs=require_hs, ignore_stereo=ignore_stereo) + try: + encoded = safe_obj.encoder( + inp, + canonical=canonical, + randomize=randomize, + constraints=constraints, + seed=seed, + ) + except SAFEFragmentationError as e: + raise e + except Exception as e: + raise SAFEEncodeError(f"Failed to encode {inp} with {slicer}") from e + return encoded + + +def decode( + safe_str: str, + as_mol: bool = False, + canonical: bool = False, + fix: bool = True, + remove_added_hs: bool = True, + remove_dummies: bool = True, + ignore_errors: bool = False, +): + """Convert input SAFE representation to smiles + Args: + safe_str: input SAFE representation to decode as a valid molecule or smiles + as_mol: whether to return a molecule object or a smiles string + canonical: whether to return a canonical smiles or a randomized smiles + fix: whether to fix the SAFE representation to take into account non-connected attachment points + remove_added_hs: whether to remove the hydrogen atoms that have been added to fix the string. 
+ remove_dummies: whether to remove dummy atoms from the SAFE representation + ignore_errors: whether to ignore error and return None on decoding failure or raise an error + + """ + with dm.without_rdkit_log(): + safe_obj = SAFEConverter() + try: + decoded = safe_obj.decoder( + safe_str, + as_mol=as_mol, + canonical=canonical, + fix=fix, + remove_dummies=remove_dummies, + remove_added_hs=remove_added_hs, + ) + + except Exception as e: + if ignore_errors: + return None + raise SAFEDecodeError(f"Failed to decode {safe_str}") from e + return decoded + +def main(): + smiles = "O=C(C#CCN1CCCCC1)Nc1ccc2ncnc(Nc3cccc(Br)c3)c2c1" + safe_string = encode(smiles) + print("SAFE representation:", safe_string) + print("SMILES representation:", decode(safe_string)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/torchtitan/utils/text_format_utils.py b/torchtitan/utils/text_format_utils.py index 65a0c15e..9514ecf8 100644 --- a/torchtitan/utils/text_format_utils.py +++ b/torchtitan/utils/text_format_utils.py @@ -1,5 +1,6 @@ # Adapted from https://github.com/YerevaNN/ChemLactica/blob/main/chemlactica/utils/text_format_utils.py # All rights reserved +from torchtitan.utils.safe import encode SPECIAL_TAGS = { "SMILES": {"start": "[START_SMILES]", "end": "[END_SMILES]"}, @@ -81,14 +82,19 @@ def delete_empty_tags(compound_json): return compound_json -def generate_formatted_string(compound_json, rng): +def generate_formatted_string(compound_json, rng, representation_type = "SMILES"): key_value_pairs = [] key = "SMILES" value = compound_json.get(key, "") + + if representation_type == "SAFE": + value = encode(value) + if rng.integers(2) == 0: if value: key_value_pairs.append(format_key_value(key, value, rng)) del compound_json[key] + keys = list(compound_json.keys()) rng.shuffle(keys) diff --git a/train.py b/train.py index b7ee0d23..797fca61 100644 --- a/train.py +++ b/train.py @@ -93,6 +93,7 @@ def main(job_config: JobConfig): tokenizer = 
build_tokenizer(tokenizer_type, job_config.model.tokenizer_path) # build dataloader + representation_type = job_config.training.representation_type data_loader = build_hf_data_loader( job_config.training.dataset, job_config.training.dataset_path, @@ -102,6 +103,7 @@ def main(job_config: JobConfig): job_config.training.seq_len, dp_degree, dp_rank, + representation_type, pin_memory = job_config.dataloader.pin_memory, num_workers = job_config.dataloader.num_workers, special_mode = job_config.dataloader.special_mode, diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 2829d098..7aa7caa1 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -44,6 +44,7 @@ tensor_parallel_degree = 1 compile = true dataset = "chemlactica_train_mini" # supported datasets: c4_test (2K), c4 (177M), chemlactica_train_mini (4K) data_process_style="chemlactica_style" +representation_type="SAFE" [experimental] pipeline_parallel_degree = 1