@@ -587,7 +587,8 @@ def __init__(self, PMGen_pdb, output_dir,
587587 only_pseudo_sequence_design = True , anchor_pred = True ,
588588 sampling_temp = 5 , batch_size = 1 , hot_spot_thr = 6.0 ,
589589 save_hotspots = True , binder_pred = False , fix_anchors = False ,
590- anchor_and_peptide = None , return_match_allele = False ):
590+ anchor_and_peptide = None , return_match_allele = False ,
591+ model_name = 'v_48_020' ):
591592 '''
592593 Args:
593594 PMGen_pdb: (str) Single Chain pdb path generated by PMGen AFfine.
@@ -609,6 +610,7 @@ def __init__(self, PMGen_pdb, output_dir,
609610 Only used if fix_anchors==True, Default: False
610611 return_match_allele: (bool) Returns a list of one or two elementsself.matched_alleles if binder_pred is True,
611612 Depending on MHC type.
613+ model_name: (str) Name of the ProteinMPNN model to use; allowed models are those found in ProteinMPNN/vanilla_model_weights. Default: 'v_48_020'
612614 '''
613615 self .pdb = PMGen_pdb
614616 self .output_dir = output_dir
@@ -627,6 +629,7 @@ def __init__(self, PMGen_pdb, output_dir,
627629 self .fix_anchors = fix_anchors
628630 self .anchor_and_peptide = anchor_and_peptide
629631 self .return_match_allele = return_match_allele
632+ self .model_name = model_name
630633 self .input_assertion ()
631634
632635 os .makedirs (self .output_dir , exist_ok = True )
@@ -684,7 +687,8 @@ def __mhc_design(self):
684687 "--seed" , "37" ,
685688 "--batch_size" , f'{ self .batch_size } ' ,
686689 "--save_probs" , "1" ,
687- "--save_score" , "1"
690+ "--save_score" , "1" ,
691+ "--model_name" , f"{ self .model_name } "
688692 ], check = True )
689693 print ('Full MHC Sequence Generation Mode Done! *****\n ' )
690694
@@ -721,6 +725,8 @@ def __peptide_design(self):
721725 "--save_probs" , "1" ,
722726 "--save_score" , "1" ,
723727 "--omit_AAs" , "X" ,
728+ "--path_to_model_weights" , "ProteinMPNN/vanilla_model_weights" ,
729+ "--model_name" , f"{ self .model_name } "
724730 ]
725731 if self .fix_anchors :# to fix anchors, fixed_pdbs file and design_only_positions should be generated
726732 # we have anchors, we need to define designable positions which are non-anchor positions
@@ -811,7 +817,8 @@ def __only_pseudo_sequence_design(self):
811817 "--seed" , "37" ,
812818 "--batch_size" , f'{ self .batch_size } ' ,
813819 "--save_probs" , "1" ,
814- "--save_score" , "1"
820+ "--save_score" , "1" ,
821+ "--model_name" , f"{ self .model_name } "
815822 ], check = True )
816823 print ('MHC Pseudo Sequence Generation Mode Done! *****\n ' )
817824
@@ -857,7 +864,8 @@ def run_single_proteinmpnn(path, directory, args, anchor_and_peptide=None):
857864 hot_spot_thr = args .hot_spot_thr ,
858865 binder_pred = args .binder_pred ,
859866 fix_anchors = args .fix_anchors ,
860- anchor_and_peptide = anchor_and_peptide
867+ anchor_and_peptide = anchor_and_peptide ,
868+ model_name = args .proteinmpnn_model_name
861869 )
862870 runner_mpnn .run () #
863871
@@ -978,25 +986,45 @@ def _process_row(self, row):
978986 assert len (mhc_seq_list ) == 1 , (f'mhc_seq for mhc_type==1, should be string with no "/", '
979987 f'found: \n { str (row .mhc_seq )} ' )
980988 parallel = True if self .args .run == 'parallel' else False
981- netmhc_df = run_and_parse_netmhcpan (peptide_fasta_file , mhc_type , self .tmp , mhc_seq_list , verbose = self .args .verbose , outfilename = str (row .id ), n_jobs = self .args .max_cores , parallel = parallel )
982- seen_cores = []
983989 results = {'anchors' : [], 'mhc_seqs' : [], 'ids' : [], 'peptides' : [], 'mhc_types' : []}
984- counter = 0
985- for j , net_row in netmhc_df .iterrows ():
986- peptide2 = str (net_row ['Core' ])
987- peptide1 = str (row .peptide )
988- predicted_anchors , pept1 , pept2 = processing_functions .align_and_find_anchors_mhc (peptide1 , peptide2 ,
989- mhc_type )
990- if not predicted_anchors in seen_cores :
991- seen_cores .append (predicted_anchors )
992- results ['anchors' ].append (";" .join ([str (pp ) for pp in predicted_anchors ]))
993- results ['mhc_seqs' ].append (str (row ['mhc_seq' ]))
990+ try :
991+ netmhc_df = run_and_parse_netmhcpan (peptide_fasta_file , mhc_type , self .tmp , mhc_seq_list , verbose = self .args .verbose , outfilename = str (row .id ), n_jobs = 1 , parallel = False )
992+ seen_cores = []
993+ counter = 0
994+ for j , net_row in netmhc_df .iterrows ():
995+ peptide2 = str (net_row ['Core' ])
996+ peptide1 = str (row .peptide )
997+ predicted_anchors , pept1 , pept2 = processing_functions .align_and_find_anchors_mhc (peptide1 , peptide2 ,
998+ mhc_type )
999+ if not predicted_anchors in seen_cores :
1000+ seen_cores .append (predicted_anchors )
1001+ results ['anchors' ].append (";" .join ([str (pp ) for pp in predicted_anchors ]))
1002+ results ['mhc_seqs' ].append (str (row ['mhc_seq' ]))
1003+ results ['ids' ].append (str (row ['id' ]) + '_' + str (counter ))
1004+ results ['peptides' ].append (str (row ['peptide' ]))
1005+ results ['mhc_types' ].append (int (row ['mhc_type' ]))
1006+ counter += 1
1007+ if counter == self .args .top_k : break
1008+ return results
1009+ except Exception as e : # if netmhcpan fails or does not exist, take all possible anchors
1010+ peptide = str (row .peptide )
1011+ mhc_type = int (row .mhc_type )
1012+ pep_len = len (peptide )
1013+ anchor_combinations = []
1014+ if mhc_type == 1 :
1015+ anchor_combinations = processing_functions .anchor_combinations_mhc1 (pep_len )
1016+ elif mhc_type == 2 :
1017+ anchor_combinations = processing_functions .anchor_combinations_mhc2 (pep_len )
1018+ assert len (anchor_combinations ) > 0 , f'no anchor combination is found for { row .id } . The peptide sequence "{ row .peptide } " length should be longer than 9 for mhc2 and 8 for mhc1'
1019+ for counter , anchors in enumerate (anchor_combinations ):
1020+ results ['anchors' ].append (';' .join ([str (i ) for i in anchors ]))
1021+ results ['mhc_seqs' ].append (str (row .mhc_seq ))
9941022 results ['ids' ].append (str (row ['id' ]) + '_' + str (counter ))
9951023 results ['peptides' ].append (str (row ['peptide' ]))
9961024 results ['mhc_types' ].append (int (row ['mhc_type' ]))
997- counter += 1
998- if counter == self . args . top_k : break
999- return results
1025+ return results
1026+
1027+
10001028
10011029 def process (self ):
10021030 """
@@ -1006,7 +1034,7 @@ def process(self):
10061034 DataFrame with processed results
10071035 """
10081036 df = pd .read_csv (self .args .df , sep = '\t ' )
1009- print (f" Starting Multiple Anchor Mode on { self .args .max_cores } cores. Make Sure NetMHCpan is installed" )
1037+ print (f" Starting Multiple Anchor Mode on { self .args .max_cores } cores. I netMHCpan is installed it is used, if not, all anchor combinations are processed " )
10101038 # Determine number of processes
10111039 num_processes = min (cpu_count (), int (self .args .max_cores ))
10121040 # Create multiprocessing pool
0 commit comments