Commit 0d1dffe7 authored by zhaojichen

add pafnucy model
# Contents
- [Contents](#contents)
- [Pafnucy Description](#pafnucy-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Export Process](#export-process)
- [Export](#export)
- [Inference Process](#inference-process)
- [Inference](#inference)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Pafnucy train on pdbbind v2016](#pafnucy-train-on-pdbbind-v2016)
- [Inference Performance](#inference-performance)
- [Pafnucy infer on PDBBindv2016](#pafnucy-infer-on-pdbbind-v2016)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Pafnucy Description](#contents)
Pafnucy is a deep convolutional neural network that predicts binding affinity for protein-ligand complexes. The complex is represented with a 3D grid, and the model utilizes a 3D convolution to produce a feature map of this representation, treating the atoms of both proteins and
ligands in the same manner. The model discovers patterns that are encoded by the filters in the convolutional layer and creates a feature map with spatial occurrences for each pattern in the data.
[Paper](https://doi.org/10.1093/bioinformatics/bty374): Marta M. Stepniewska-Dziubinska, Piotr Zielenkiewicz, and Pawel Siedlecki. "Development and evaluation of a deep learning model for protein–ligand binding affinity prediction", published in Bioinformatics.
# [Model Architecture](#contents)
The Pafnucy model consists of two parts, a convolutional part and a dense part, with different types of connections between layers. The input is processed by a block of 3D convolutional layers combined with max pooling. Pafnucy uses 3 convolutional layers with 64, 128, and 256 filters, each followed by a max pooling layer. The output of the last convolutional layer is flattened and used as input for a block of fully connected layers.
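A minimal sketch of this architecture in MindSpore is shown below. It uses the filter counts, kernel sizes, and dense-layer sizes from `default_config.yaml`; the class name, layer ordering details, and API specifics (e.g. availability of `nn.Conv3d`/`nn.MaxPool3d` and `nn.Dropout`'s `keep_prob` argument in your MindSpore version) are assumptions. The actual implementation lives in `src/net.py` (`SBNetWork`).

```python
import mindspore.nn as nn


class PafnucySketch(nn.Cell):
    """Illustrative sketch only; the real model is SBNetWork in src/net.py."""
    def __init__(self, in_channels=19, conv_channels=(64, 128, 256),
                 dense_sizes=(1000, 500, 200), isize=21, keep_prob=0.5):
        super().__init__()
        conv = []
        prev, side = in_channels, isize
        for ch in conv_channels:
            conv.append(nn.Conv3d(prev, ch, kernel_size=5, pad_mode='same'))
            conv.append(nn.ReLU())
            conv.append(nn.MaxPool3d(kernel_size=2, stride=2))  # assumes 'valid' pooling
            prev, side = ch, side // 2                          # 21 -> 10 -> 5 -> 2
        self.conv = nn.SequentialCell(conv)
        self.flatten = nn.Flatten()
        dense = []
        prev = conv_channels[-1] * side ** 3
        for size in dense_sizes:
            dense.append(nn.Dense(prev, size))
            dense.append(nn.ReLU())
            dense.append(nn.Dropout(keep_prob=keep_prob))  # keep_prob argument assumed (MindSpore 1.x style)
            prev = size
        dense.append(nn.Dense(prev, 1))  # single predicted binding affinity
        self.dense = nn.SequentialCell(dense)

    def construct(self, x):
        # x: (batch, 19, 21, 21, 21) grid of atomic features
        return self.dense(self.flatten(self.conv(x)))
```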
# [Dataset](#contents)
Note that you can run the scripts based on the dataset mentioned in the original paper or one widely used in the relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the dataset below.
Dataset used: [PDBbind v2016](http://www.pdbbind.org.cn/download.php)
- Dataset size: 2.82 GB
    - Protein-ligand complexes (general set minus refined set): 9228 complexes in total
    - Protein-ligand complexes (refined set): 4057 complexes in total
    - Ligand molecules in the general set (Mol2 format)
    - Ligand molecules in the general set (SDF format)
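The preprocessing script `process_pdbbind_data.py` expects the extracted PDBbind v2016 archives to be laid out roughly as follows; the directory and index-file names are taken from the paths referenced in that script, and the exact layout of your local copy may differ:

```text
pdbbind/dataset/path/
├── PDBbind_2016_plain_text_index/
│   └── index/
│       ├── INDEX_general_PL_data.2016
│       ├── INDEX_refined_data.2016
│       └── INDEX_core_data.2016
├── general-set-except-refined/     # e.g. 10gs/10gs_ligand.mol2, 10gs/10gs_pocket.mol2
├── refined-set/
└── v2013-core/                     # used to build core_pdbbind2013.ids
```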
# [Features](#contents)
## Mixed Precision
The [mixed precision](https://www.mindspore.cn/tutorials/experts/en/master/others/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
For FP16 operators, if the input data type is FP32, the MindSpore backend will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and then searching for 'reduce precision'.
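As an illustration, mixed precision is typically enabled through the `amp_level` argument of MindSpore's `Model` wrapper. The snippet below is only a sketch and is not taken from `train.py`; the optimizer choice and `amp_level` value are assumptions (`SBNetWork` and `config` come from this repository, and the network computes the MSE loss internally, as in `eval.py`):

```python
import mindspore.nn as nn
from mindspore import Model
from src.net import SBNetWork
from src.model_utils.config import config

# Build the network as in eval.py; keep_prob from the config enables dropout for training.
net = SBNetWork(in_chanel=[19, 64, 128],
                out_chanle=config.conv_channels,
                dense_size=config.dense_sizes,
                lmbda=config.lmbda,
                isize=config.isize, keep_prob=config.keep_prob)
opt = nn.Adam(net.trainable_params(), learning_rate=config.lr)
# amp_level="O3" casts the network to float16 on Ascend; "O0" keeps pure float32.
model = Model(net, optimizer=opt, amp_level="O3")
```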
# [Environment Requirements](#contents)
- Hardware (Ascend)
- Prepare hardware environment with Ascend/GPU/CPU processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
- dataset preprocess
```python
# Before preparing the data, please make sure OpenBabel is installed. The OpenBabel version we used is v2.4.1.
# First, run the following example to preprocess the PDBbind dataset.
python process_pdbbind_data.py --data_path pdbbind/dataset/path/
# Then, use the split_dataset.py script to split it into 3 subsets.
python src/split_dataset.py -i processed/data/path -o output/path
# Finally, convert the subsets to MindRecord.
python src/create_mindrecord.py --data_path ~/splited_data/path --mindrecord_path output/path
```
- running on Ascend
```yaml
# Add dataset path
mindrecord_path: /home/DataSet/mindrecord_path/
# Add checkpoint path parameter before inference
ckpt_file: /home/model/pafnucy/ckpt/pafnucy.ckpt
```
```python
# run training example
python train.py > train.log 2>&1 &
# run distributed training example
bash run_distribution_train.sh [MINDRECORD_PATH] [RANKTABLE_PATH] [DEVICE_NUM]
# example: bash run_distribution_train.sh ~/mindrecord_path/ ~/hccl_8p.json 8
# run standalone training example
bash run_standalone_train.sh [MINDRECORD_PATH] [DEVICE_ID]
# run evaluation example
python eval.py > eval.log 2>&1 &
OR
bash run_eval.sh [MINDRECORD_PATH] [CKPT_PATH] [DEVICE_ID]
# example: bash run_eval.sh ~/mindrecord_path/ ~/pafnucy.ckpt 1
# The predict process can be used to score molecular complexes. As input it takes 3D grids, with each grid point described by 19 atomic features.
# First, create grids from molecular structures using the following command.
python prepare.py -l ligand1.mol2 -p pocket.mol2 -o complexes.hdf
# Then, once the complexes are prepared, score them with the trained network.
python predict.py --predict_input /path/to/complexes.hdf --ckpt_file /path/to/pafnucy.ckpt
```
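For reference, the 19 atomic features per grid point mentioned above break down, according to the default `Featurizer` in `src/data.py`, into 9 atom-type one-hot classes, 4 pybel atomic properties (`hyb`, `heavyvalence`, `heterovalence`, `partialcharge`), 1 molecule code distinguishing ligand from protein, and 5 SMARTS-based flags. A quick sanity check (requires openbabel/pybel to be installed):

```python
from src.data import Featurizer

featurizer = Featurizer()
print(len(featurizer.FEATURE_NAMES))   # 19 = 9 + 4 + 1 + 5
print(featurizer.FEATURE_NAMES[-5:])   # ['hydrophobic', 'aromatic', 'acceptor', 'donor', 'ring']
```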
For distributed training, an HCCL configuration file in JSON format needs to be created in advance.
Please follow the instructions in the link below:
<https://gitee.com/mindspore/models/tree/master/utils/hccl_tools>.
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```text
├── model_zoo
├── README.md // descriptions about all the models
├── pafnucy
├── README.md // descriptions about pafnucy
├── scripts
│ ├──run_distribution_train.sh // shell script for distributed on Ascend
│ ├──run_standalone_train.sh // shell script for standalone on Ascend
│ ├──run_eval.sh // shell script for evaluation on Ascend
├── src
│ ├──create_mindrecord.py // convert the processed dataset to MindRecord
│ ├──data.py // atomic feature calculation
│ ├──dataloader.py // dataset loading
│ ├──logger.py // logger module
│ ├──net.py // pafnucy architecture
│ ├──split_dataset.py // split dataset
├── train.py // training script
├── process_pdbbind_data.py // dataset preprocess
├── prepare.py // prepare complexes
├── predict.py // score complexes
├── eval.py // evaluation script
├── default_config.yaml // config file
├── export.py // export checkpoint files into air/mindir
```
## [Script Parameters](#contents)
Parameters for both training and evaluation can be set in `default_config.yaml`
- config for Pafnucy
```python
'grid_spacing': 1.0 # distance between grid points
'lr': 1e-5 # learning rate
'momentum': 0.9 # momentum
'weight_decay': 0.001 # weight decay
'epoch_size': 20 # epoch number
'batch_size': 20 # batch size
'max_dist': 10.0 # max distance from complex center
'conv_patch': 5 # kernel size for convolutional layers
'pool_patch': 2 # kernel size for pooling layers
'conv_channels': [64, 128, 256] # number of filters in convolutional layers
'dense_sizes': [1000, 500, 200] # number of neurons in dense layers
'keep_prob': 0.5 # dropout keep probability
'rotations': 24 # rotations to perform
```
For more configuration details, please refer to the script `default_config.yaml`.
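Note that `max_dist` and `grid_spacing` together determine the cubic grid size used throughout the code: a box extending `max_dist` in every direction from the complex center, sampled every `grid_spacing`, gives `2 * max_dist / grid_spacing + 1` points per side, which is where the `isize: 21` setting and the `[19, 21, 21, 21]` input shape come from:

```python
max_dist, grid_spacing = 10.0, 1.0           # defaults from default_config.yaml
isize = int(2 * max_dist / grid_spacing) + 1
print(isize)                                 # 21
```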
## [Training Process](#contents)
### Training
- running on Ascend
```python
python train.py > train.log 2>&1 &
```
The python command above will run in the background; you can view the results through the file `train.log`.
After training, you'll get some checkpoint files under the script folder by default. The loss values will look as follows:
```text
# grep "loss is " train.log
epoch[0], train error: [0.0004304781323298812], Validation error: [2.4519902387857444], Validation RMSE: [1.5658832136483694]
epoch[1], train error: [0.001452913973480463], Validation error: [2.4301812992095946], Validation RMSE: [1.5589038774759638]
...
```
The model checkpoint will be saved in the current directory.
### Distributed Training
- running on Ascend
```bash
bash run_distribution_train.sh [MINDRECORD_PATH] [RANKTABLE_PATH] [DEVICE_NUM]
```
The above shell script will run distributed training in the background. You can view the results through the file `train_parallel[X]/year-month-day-time-*-rank-[rank_id].log`. The loss values will look as follows:
```text
# grep "result: " train_parallel*/year-mouth-day-time-*-rank-[rank_id].log
train_parallel0/log:epoch: 1 step: 48, loss is 1.4302931
train_parallel0/log:epcoh: 2 step: 48, loss is 1.4023874
...
train_parallel1/log:epoch: 1 step: 48, loss is 1.3458025
train_parallel1/log:epcoh: 2 step: 48, loss is 1.3729336
...
...
```
## [Evaluation Process](#contents)
### Evaluation
- evaluation on pdbbind validation dataset when running on Ascend
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/pafnucy/pafnucy.ckpt".
```python
python eval.py > eval.log 2>&1 &
OR
bash run_eval.sh ~/mindrecord_path/ ~/pafnucy.ckpt 1
```
The above python command will run in the background. You can view the results through the file "year-month-day-time-*-rank-[rank_id].log". The accuracy of the test dataset will be as follows:
```text
# grep "Validation RMSE: " year-mouth-day-time-*-rank-[rank_id].log
2022-07-06 11:53:08,535:INFO:Validation RMSE: [1.4378043893221764]
```
Note that for evaluation after distributed training, please set `ckpt_file` to the last saved checkpoint file, such as "scripts/train_parallel0/pafnucy.ckpt". The accuracy of the test dataset will be as follows:
```text
# grep "Validation RMSE: " eval/year-mouth-day-time-*-rank-[rank_id].log
2022-07-06 11:53:08,535:INFO:Validation RMSE: [1.4378043893221764]
```
## [Export Process](#contents)
### [Export](#contents)
Before exporting the model, you must modify the config file `default_config.yaml`.
The config items you should modify are `batch_size` and `ckpt_file`.
```shell
python export.py --ckpt_file path/to/checkpoint --file_format file_format
```
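For example, using the default values from `default_config.yaml` (the checkpoint path below is illustrative; `MINDIR` and `AIR` are the formats mentioned for `export.py`):

```shell
python export.py --ckpt_file ./scripts/train/ckpt/pafnucy.ckpt --file_format MINDIR
```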
# [Model Description](#contents)
## [Performance](#contents)
### Training Performance
#### Pafnucy train on pdbbind v2016
| Parameters | Ascend |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | Pafnucy |
| Resource | Ascend 910; CPU 2.60 GHz, 192 cores; Memory 755 GB; OS Euler2.8 |
| Uploaded Date | 07/06/2022 (month/day/year) |
| MindSpore Version | 1.6.0 |
| Dataset | PDBBind v2016 |
| Training Parameters | epoch=20, batch_size = 20, lr=1e-5 |
| Optimizer | Adam |
| Loss Function | MSELoss |
| Loss | 0.0016 |
| Speed | 1pc: 13 ms/step; 8pcs: 12 ms/step |
| Total time | 1pc: 99 mins; 8pcs: 25 mins |
| Parameters (M) | 13.0 |
| Checkpoint for Fine tuning | 147M (.ckpt file) |
| Model for inference | 49M (.mindir file), 49M(.air file) |
| Scripts | [Pafnucy script](https://gitee.com/mindspore/models/tree/master/research/hpc/pafnucy) |
### Inference Performance
#### Pafnucy infer on PDBBindv2016
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | Pafnucy |
| Resource | Ascend 910; OS Euler2.8 |
| Uploaded Date | 07/06/2022 (month/day/year) |
| MindSpore Version | 1.6.1 |
| Dataset | PDBBind v2016 |
| batch_size | 20 |
| outputs | predicted binding affinity |
| Accuracy (RMSE) | 1 pc: 1.44 |
# [Description of Random Situation](#contents)
In dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py.
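A typical way to fix these seeds in MindSpore looks like the sketch below; the seed value and the exact placement in this repository's scripts are assumptions:

```python
import numpy as np
import mindspore as ms

ms.set_seed(1)      # fixes MindSpore's global seed (parameter init, dataset shuffling)
np.random.seed(1)   # fixes NumPy-based randomness such as the random rotations
```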
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/models).
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/data/pafnucy/tests/data/dataset"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
mindrecord_path: ""
device_target: Ascend
enable_profiling: False
distribute: False
ckpt_path: '/cache/train/'
ckpt_file: './scripts/train/ckpt/pafnucy.ckpt'
# ==============================================================================
# Training options
grid_spacing: 1.0
lr: 1e-5
momentum: 0.9
weight_decay: 0.001
epoch_size: 20
batch_size: 20
max_dist: 10.0
conv_patch: 5
pool_patch: 2
conv_channels: [64, 128, 256]
dense_sizes: [1000, 500, 200]
keep_prob: 0.5
isize: 21
lmbda: 0.001
rotations: 24
air_name: "pafnucy"
device_id: 5
log_interval: 1
file_name: "pafnucy"
file_format: 'MINDIR'
dataset_sink_mode: True
save_checkpoint: True
save_checkpoint_epochs: 2
# acc calculation
result_path: ''
img_path: ''
#Testing options
hdf_file: './complexes.hdf'
charge_scaler: 0.425896
verbose: True
#Prepare molecular data for the network
ligands: ["/data/complexes/10gs/10gs_ligand.mol2"]
pockets: ["/data/complexes/10gs/10gs_pocket.mol2"]
ligand_format: "mol2"
pocket_format: "mol2"
output: "complexes.hdf"
mode: 'w'
affinities: ""
pre_output: "./predictions.csv"
predict_input: "./complexes.hdf"
---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'
file_name: 'output file name.'
file_format: 'file format'
result_path: "result files path."
img_path: "image file path."
grid_spacing: "distance between grid points"
max_dist: "max distance from complex center"
conv_patch: "kernel size for convolutional layers"
pool_patch: "kernel size for pooling layers"
conv_channels: "number of filters in convolutional layers"
dense_sizes: "number of neurons in dense layers"
rotations: "rotations to perform"
charge_scaler: "scaling factor for the charge (use the same factor when preparing data for training and for predictions)"
verbose: "whether to print messages."
ligands: "files with ligands' structures"
pockets: "files with pockets' structures"
ligand_format: "file format for the ligand, must be supported by openbabel"
pocket_format: "file format for the pocket, must be supported by openbabel"
output: "name for the file with the prepared structures"
mode: "mode for the output file"
affinities: "CSV table with affinity values; it must contain two columns: `name`, which must be equal to the ligand's
file name without extension, and `affinity`, which must contain floats"
---
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluation"""
import os
import numpy as np
import pandas as pd
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.net import SBNetWork
from src.logger import get_logger
from src.model_utils.config import config
from src.dataloader import minddataset_loader_val
from src.model_utils.moxing_adapter import moxing_wrapper
def modelarts_pre_process():
pass
def get_data_size(csv_path):
csv_path = os.path.join(csv_path, './ds_size.csv')
result = pd.read_csv(csv_path)
val_size = result['size'][1]
return val_size
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_eval():
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=config.device_id)
network = SBNetWork(in_chanel=[19, 64, 128],
out_chanle=config.conv_channels,
dense_size=config.dense_sizes,
lmbda=config.lmbda,
isize=config.isize, keep_prob=1.0)
network.set_train(False)
param_dict = load_checkpoint(config.ckpt_file)
load_param_into_net(network, param_dict)
val_size = get_data_size(config.mindrecord_path)
val_path = './val/validation_dataset.mindrecord'
val_rotation_data, val_rot_weight = minddataset_loader_val(configs=config,
mindfile=os.path.join(config.mindrecord_path, val_path),
no_batch_size=val_size)
val_data_loader = val_rotation_data.create_dict_iterator()
val_data_size = val_rotation_data.get_dataset_size()
config.logger.info("Finish Load dataset and Network. dataset size: validation %d", val_data_size)
final_mse_v = 0
for _, vdata in enumerate(val_data_loader):
coord_features = vdata['coords_features']
affinity = vdata['affinitys']
mse_v = network(coord_features, affinity)
temp_mse_v = mse_v.asnumpy() * val_rot_weight
final_mse_v += temp_mse_v
config.logger.info('Validation RMSE: [%.2f]', np.sqrt(final_mse_v))
if __name__ == '__main__':
config.logger = get_logger('./', config.device_id)
config.logger.save_args(config)
run_eval()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export"""
import numpy as np
import mindspore as ms
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from src.net import SBNetWork
from src.model_utils.config import config
def run_export():
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
        context.set_context(device_id=config.device_id)
net = SBNetWork(in_chanel=[19, 64, 128],
out_chanle=config.conv_channels,
dense_size=config.dense_sizes,
osize=1, lmbda=0.001,
isize=21, keep_prob=1.0)
assert config.ckpt_file is not None, "config.ckpt_file is None."
param_dict = load_checkpoint(config.ckpt_file)
load_param_into_net(net, param_dict)
coor = Tensor(np.ones([1, 19, 21, 21, 21]), ms.float32)
affine = Tensor(np.ones([1, 1]), ms.float32)
inputs = [coor, affine]
export(net, *inputs, file_name=config.file_name, file_format=config.file_format)
if __name__ == '__main__':
run_export()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""score complexes"""
import os
import numpy as np
import pandas as pd
import h5py
from mindspore import context
from mindspore.common.tensor import Tensor
import mindspore.dataset as ds
from mindspore.common import dtype as mstype
import mindspore.dataset.transforms.c_transforms as C
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.net import SBNetWork
from src.logger import get_logger
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.data import Featurizer, make_grid
def modelarts_pre_process():
pass
def input_file(in_paths):
"Check if input file exists"
in_paths = os.path.abspath(in_paths)
if not os.path.exists(in_paths):
raise IOError('File %s does not exist.' % in_paths)
return in_paths
def eval_dataset(configs, data_path):
"""eval dataset process"""
featurizer = Featurizer()
charge_column = featurizer.FEATURE_NAMES.index('partialcharge')
coords = []
features = []
names = []
with h5py.File(data_path, 'r') as f:
for name in f:
names.append(name)
dataset = f[name]
coords.append(dataset[:, :3])
features.append(dataset[:, 3:])
if configs.verbose:
if configs.batch_size == 0:
configs.logger.info('Predict for all complexes at once')
else:
configs.logger.info('%s samples per batch' % configs.batch_size)
evaldata = []
for crd, f in zip(coords, features):
evaldata.append(make_grid(crd, f, max_dist=configs.max_dist,
grid_resolution=configs.grid_spacing))
batch_grid = np.vstack(evaldata)
batch_grid[..., charge_column] /= configs.charge_scaler
batch_grid = np.transpose(batch_grid, axes=(0, 4, 1, 2, 3))
batch_grid = np.expand_dims(batch_grid, axis=1)
print("batch grid: ", batch_grid.shape, names)
return batch_grid, names
class EvalDatasetIter:
"""Evaluation dataset iterator"""
def __init__(self, grids):
self.grids = grids
def __getitem__(self, index):
return self.grids[index]
def __len__(self):
return len(self.grids)
def load_evaldata(configs, data_path):
"""dataset loader"""
batch_grid, names = eval_dataset(configs, data_path)
eval_data = EvalDatasetIter(batch_grid)
eval_loader = ds.GeneratorDataset(eval_data, column_names=['grid'])
type_cast_op = C.TypeCast(mstype.float32)
eval_loader = eval_loader.map(type_cast_op, input_columns=['grid'])
eval_loader = eval_loader.batch(batch_size=20)
return eval_loader, names
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_eval():
    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=config.device_id)
network = SBNetWork(in_chanel=[19, 64, 128],
out_chanle=config.conv_channels,
dense_size=config.dense_sizes,
lmbda=config.lmbda,
isize=config.isize, keep_prob=1.0, is_training=False)
network.set_train(False)
hdf_file_path = input_file(config.hdf_file)
data_loader, names = load_evaldata(config, hdf_file_path)
param_dict = load_checkpoint(config.ckpt_file)
load_param_into_net(network, param_dict)
prediction = []
for data in data_loader.create_dict_iterator(output_numpy=True):
grids = data['grid']
if len(grids.shape) == 5:
grids = Tensor(grids)
elif len(grids.shape) == 6:
grids = Tensor(np.squeeze(grids, 1))
        else:
            config.logger.info("Wrong input shape, please check dataset preprocess.")
            continue
preds = network(grids)
prediction.append(preds.asnumpy())
config.logger.info("Finishing Evaluate.......")
results = pd.DataFrame({'name': names, 'prediction': np.vstack(prediction).flatten()})
results.to_csv(config.pre_output, index=False)
config.logger.info('Result saved to %s', config.pre_output)
if __name__ == '__main__':
config.logger = get_logger('./', config.device_id)
config.logger.save_args(config)
path = config.predict_input
paths = input_file(path)
config.hdf_file = paths
run_eval()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""prepare complex for predict"""
import os
import ast
import numpy as np
import pandas as pd
import h5py
import pybel
from src.data import Featurizer
def get_pocket(configs, num_pockets, featurizer, num_ligands):
if num_pockets > 1:
for pocket_file in configs.pocket:
if configs.verbose:
print('reading %s' % pocket_file)
try:
pocket = next(pybel.readfile(configs.pocket_format, pocket_file))
except:
raise IOError('Cannot read %s file' % pocket_file)
pocket_coords, pocket_features = featurizer.get_features(pocket, molcode=-1)
yield (pocket_coords, pocket_features)
else:
pocket_file = configs.pocket[0]
try:
pocket = next(pybel.readfile(configs.pocket_format, pocket_file))
except:
raise IOError('Cannot read %s file' % pocket_file)
pocket_coords, pocket_features = featurizer.get_features(pocket, molcode=-1)
for _ in range(num_ligands):
yield (pocket_coords, pocket_features)
def input_file(path):
"""Check if input file exists."""
path = os.path.abspath(path)
if not os.path.exists(path):
raise IOError('File %s does not exist.' % path)
return path
def output_file(path):
"""Check if output file can be created."""
path = os.path.abspath(path)
dirname = os.path.dirname(path)
if not os.access(dirname, os.W_OK):
raise IOError('File %s cannot be created (check your permissions).'
% path)
return path
def prepare(configs):
num_pockets = len(configs.pocket)
num_ligands = len(configs.ligand)
featurizer = Featurizer()
if num_pockets not in (1, num_ligands):
raise IOError('%s pockets specified for %s ligands. You must either provide '
'a single pocket or a separate pocket for each ligand' % (num_pockets, num_ligands))
if configs.verbose:
print('%s ligands and %s pockets to prepare:' % (num_ligands, num_pockets))
if num_pockets == 1:
print(' pocket: %s' % configs.pocket[0])
for ligand_file in configs.ligand:
print(' ligand: %s' % ligand_file)
else:
for ligand_file, pocket_file in zip(configs.ligand, configs.pocket):
print(' ligand: %s, pocket: %s' % (ligand_file, pocket_file))
print('\n\n')
if configs.affinities:
affinities = pd.read_csv(configs.affinities)
if 'affinity' not in affinities.columns:
raise ValueError('There is no `affinity` column in the table')
if 'name' not in affinities.columns:
raise ValueError('There is no `name` column in the table')
affinities = affinities.set_index('name')['affinity']
else:
affinities = None
with h5py.File(configs.output, configs.mode) as f:
pocket_generator = get_pocket(configs, num_pockets, featurizer=featurizer, num_ligands=num_ligands)
for ligand_file in configs.ligand:
# use filename without extension as dataset name
name = os.path.splitext(os.path.split(ligand_file)[1])[0]
if configs.verbose:
print('reading %s' % ligand_file)
try:
ligand = next(pybel.readfile(configs.ligand_format, ligand_file))
except:
raise IOError('Cannot read %s file' % ligand_file)
ligand_coords, ligand_features = featurizer.get_features(ligand, molcode=1)
pocket_coords, pocket_features = next(pocket_generator)
centroid = ligand_coords.mean(axis=0)
ligand_coords -= centroid
pocket_coords -= centroid
data = np.concatenate(
(np.concatenate((ligand_coords, pocket_coords)),
np.concatenate((ligand_features, pocket_features))),
axis=1,
)
dataset = f.create_dataset(name, data=data, shape=data.shape,
dtype='float32', compression='lzf')
if affinities is not None:
dataset.attrs['affinity'] = affinities.loc[name]
if configs.verbose:
print('\n\ncreated %s with %s structures' % (configs.output, num_ligands))
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(
description='Prepare molecular data for the network',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--ligand', '-l', required=True, type=input_file, nargs='+',
help='files with ligands\' structures')
parser.add_argument('--pocket', '-p', required=True, type=input_file, nargs='+',
help='files with pockets\' structures')
parser.add_argument('--ligand_format', type=str, default='mol2',
help='file format for the ligand,'
' must be supported by openbabel')
parser.add_argument('--pocket_format', type=str, default='mol2',
help='file format for the pocket,'
' must be supported by openbabel')
parser.add_argument('--output', '-o', default='./complexes.hdf',
type=output_file,
help='name for the file with the prepared structures')
parser.add_argument('--mode', '-m', default='w',
type=str, choices=['r+', 'w', 'w-', 'x', 'a'],
help='mode for the output file (see h5py documentation)')
parser.add_argument('--affinities', '-a', default=None, type=input_file,
help='CSV table with affinity values.'
' It must contain two columns: `name` which must be'
' equal to ligand\'s file name without extension,'
' and `affinity` which must contain floats')
parser.add_argument('--verbose', '-v', default=True, type=ast.literal_eval,
help='whether to print messages')
args = parser.parse_args()
prepare(configs=args)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""process raw pdbbindv2016 data"""
import os
import argparse
import warnings
import h5py
import pybel
import numpy as np
import pandas as pd
from src.data import Featurizer
def extractFeature(path, affinity_data, dataset_path):
featurizer = Featurizer()
charge_idx = featurizer.FEATURE_NAMES.index('partialcharge')
with h5py.File('%s/core2013.hdf' % path, 'w') as g:
j = 0
for dataset_name, data in affinity_data.groupby('set'):
print(dataset_name, 'set')
i = 0
ds_path = dataset_path[dataset_name]
print(ds_path)
with h5py.File('%s/%s.hdf' % (path, dataset_name), 'w') as f:
for _, row in data.iterrows():
name = row['pdbid']
affinity = row['Kd_Ki']
ligand = next(pybel.readfile('mol2', '%s/%s/%s/%s_ligand.mol2' % (path, ds_path, name, name)))
# do not add the hydrogens! they are in the structure and it would reset the charges
try:
pocket = next(pybel.readfile('mol2', '%s/%s/%s/%s_pocket.mol2' % (path, ds_path, name, name)))
# do not add the hydrogens! they were already added in chimera and it would reset the charges
except ValueError:
                        warnings.warn('no pocket for %s (%s set)' % (name, dataset_name))
continue
ligand_coords, ligand_features = featurizer.get_features(ligand, molcode=1)
assert (ligand_features[:, charge_idx] != 0).any()
pocket_coords, pocket_features = featurizer.get_features(pocket, molcode=-1)
centroid = ligand_coords.mean(axis=0)
ligand_coords -= centroid
pocket_coords -= centroid
data = np.concatenate((np.concatenate((ligand_coords, pocket_coords)),
np.concatenate((ligand_features, pocket_features))), axis=1)
if row['include']:
dataset = f.create_dataset(name, data=data, shape=data.shape,
dtype='float32', compression='lzf')
dataset.attrs['affinity'] = affinity
i += 1
else:
dataset = g.create_dataset(name, data=data, shape=data.shape,
dtype='float32', compression='lzf')
dataset.attrs['affinity'] = affinity
j += 1
print('prepared', i, 'complexes')
print('excluded', j, 'complexes')
def transpdb2mol2(path, dataset_name):
for dataset in dataset_name.values():
data_path = os.path.join(path, dataset)
for die_path, _, pdbfile in os.walk(data_path):
for pfile in pdbfile:
if "_pocket.pdb" in pfile:
p_real_file = os.path.join(die_path, pfile)
molfile = p_real_file.replace(".pdb", ".mol2")
command = "obabel -i pdb %s -o mol2 -O %s" % (p_real_file, molfile)
os.system(command)
print("Finish trans pdb to mol2 format.")
def ParseandClean(paths):
files = os.path.join(paths, 'PDBbind_2016_plain_text_index/index/INDEX_general_PL_data.2016')
if os.path.exists('./affinity_data.csv'):
os.remove('./affinity_data.csv')
# Save binding affinities to csv file
result = pd.DataFrame(columns=('pdbid', 'Kd_Ki'))
for line in open(files):
line = line.rstrip()
if line.startswith('#') or line == '':
continue
it = line.split(maxsplit=7)
pdbid, log_kdki = it[0], it[3]
result = result.append(
pd.DataFrame({'pdbid': [pdbid], 'Kd_Ki': [log_kdki]}),
ignore_index=True)
result.to_csv('affinity_data.csv', sep=",", index=False)
affinity_data = pd.read_csv('affinity_data.csv', comment='#')
# Find affinities without structural data (i.e. with missing directories)
missing = []
for misdata in affinity_data['pdbid']:
gser = os.path.join(paths, f'general-set-except-refined/{misdata}')
refined_set = os.path.join(paths, f'refined-set/{misdata}')
if not os.path.exists(gser) and not os.path.exists(refined_set):
missing.append(misdata)
missing = set(missing)
affinity_data = affinity_data[~np.in1d(affinity_data['pdbid'], list(missing))]
print("Missing length: ", len(missing))
print(affinity_data['Kd_Ki'].isnull().any())
# Separate core, refined, and general sets
core_file = os.path.join(paths, 'PDBbind_2016_plain_text_index/index/INDEX_core_data.2016')
core_set = []
for c_line in open(core_file):
c_line = c_line.rstrip()
if c_line.startswith('#') or c_line == '':
continue
c_it = c_line.split(maxsplit=7)
core_set.append(c_it[0])
core_set = set(core_set)
print('Core Set length: ', len(core_set))
refined_file = os.path.join(paths, 'PDBbind_2016_plain_text_index/index/INDEX_refined_data.2016')
refined_set = []
for rf_line in open(refined_file):
rf_line = rf_line.rstrip()
if rf_line.startswith('#') or rf_line == '':
continue
rf_it = rf_line.split(maxsplit=7)
refined_set.append(rf_it[0])
refined_set = set(refined_set)
general_set = set(affinity_data['pdbid'])
assert core_set & refined_set == core_set
assert refined_set & general_set == refined_set
print("Refined Set Length: ", len(refined_set))
print("General Set Length: ", len(general_set))
#exclude v2013 core set -- it will be used as another test set
core2013_file = os.path.join(paths, 'core_pdbbind2013.ids')
core2013 = []
for c2_line in open(core2013_file):
c2_it = c2_line.rstrip()
core2013.append(c2_it)
core2013 = set(core2013)
print("Core2013 length: ", len(core2013))
print(affinity_data.head())
print(len(core2013 & (general_set - core_set)))
affinity_data['include'] = True
affinity_data.loc[np.in1d(affinity_data['pdbid'], list(core2013 & (general_set - core_set))), 'include'] = False
affinity_data.loc[np.in1d(affinity_data['pdbid'], list(general_set)), 'set'] = 'general'
affinity_data.loc[np.in1d(affinity_data['pdbid'], list(refined_set)), 'set'] = 'refined'
affinity_data.loc[np.in1d(affinity_data['pdbid'], list(core_set)), 'set'] = 'core'
print(affinity_data.head())
print(affinity_data[affinity_data['include']].groupby('set').apply(len).loc[['general', 'refined', 'core']])
if os.path.exists('./affinity_data_cleaned.csv'):
os.remove('./affinity_data_cleaned.csv')
affinity_data[['pdbid']].to_csv('pdb.ids', header=False, index=False)
affinity_data[['pdbid', 'Kd_Ki', 'set']].to_csv('affinity_data_cleaned.csv', index=False)
#Parse Molecules
dataset_path = {'general': 'general-set-except-refined', 'refined': 'refined-set', 'core': 'refined-set'}
transpdb2mol2(paths, dataset_path)
extractFeature(path=paths, affinity_data=affinity_data, dataset_path=dataset_path)
print("Finish process data.")
with h5py.File('%s/core.hdf' % paths, 'r') as f, \
h5py.File('%s/core2013.hdf' % paths, 'r+') as g:
for name in f:
if name in core2013:
dataset = g.create_dataset(name, data=f[name])
dataset.attrs['affinity'] = f[name].attrs['affinity']
print("Finish All..........")
def Extrct2013ids(in_paths):
"""Extract pdbbind2013 index"""
filepath = os.path.join(in_paths, './v2013-core')
file_idx = os.listdir(filepath)
for items in file_idx:
with open(os.path.join(in_paths, 'core_pdbbind2013.ids'), 'a') as f:
f.write(items+'\n')
print("extract 2013 index done!")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Preprocess pdbbind data")
parser.add_argument('--data_path', type=str, required=True, default='',
help='Dataset process.')
args = parser.parse_args()
data_paths = args.data_path
    if not os.path.exists(os.path.join(data_paths, 'PDBbind_2016_plain_text_index/index/INDEX_general_PL_data.2016')):
        raise IOError("INDEX_general_PL_data.2016 file doesn't exist!")
    if not os.path.exists(os.path.join(data_paths, 'PDBbind_2016_plain_text_index/index/INDEX_core_data.2016')):
        raise IOError("INDEX_core_data.2016 file doesn't exist!")
    if not os.path.exists(os.path.join(data_paths, 'PDBbind_2016_plain_text_index/index/INDEX_refined_data.2016')):
        raise IOError("INDEX_refined_data.2016 file doesn't exist!")
if os.path.exists(os.path.join(data_paths, 'core_pdbbind2013.ids')):
print("Remove Exist core_pdbbind2013.ids file.")
os.remove(os.path.join(data_paths, 'core_pdbbind2013.ids'))
Extrct2013ids(data_paths)
ParseandClean(data_paths)
openbabel==2.4.1
seaborn==0.11.2
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train.sh MINDRECORD_PATH RANK_TABLE DEVICE_NUM"
echo "For example: bash run_distribute_train.sh /path/mindrecord_path /path/rank_table device_num"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
if [[ $# -lt 3 ]]; then
    echo "Usage: bash run_distribution_train.sh [MINDRECORD_PATH] [RANK_TABLE] [DEVICE_NUM]"
    exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET=$(get_real_path $1)
ulimit -u unlimited
export DEVICE_NUM=$3
export RANK_SIZE=$3
export RANK_TABLE_FILE=$(get_real_path $2)
BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
echo "config file path $CONFIG_FILE"
echo "$EXEC_PATH"
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
for((i=0;i<${RANK_SIZE};i++))
do
rm -rf device$i
mkdir ./device$i
cp ../*.py ./device$i
cp ../*.yaml ./device$i
cp -r ../src ./device$i
cp -r ../scripts/*.sh ./device$i
cd ./device$i || exit
mkdir ckpt
echo "start training for device $DEVICE_ID"
export DEVICE_ID=$i
export RANK_ID=$i
echo "start training for device $i"
env > env$i.log
nohup python3 -u train.py --config_path=$CONFIG_FILE --mindrecord_path=${DATASET} \
--enable_modelarts=False --batch_size=8 \
--distribute=True > log.txt 2>&1 &
echo "$i finish"
cd ../
done
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_eval.sh MINDRECORD_PATH CKPT_PATH DEVICE_ID"
echo "for example: bash run_eval.sh /data/path checkpoint/path device_id"
echo "It is better to use absolute path."
echo "Please pay attention that the dataset should corresponds to dataset_name"
echo "=============================================================================================================="
if [[ $# -lt 3 ]]; then
echo "Usage: bash run_eval.sh [MINDRECORD_PATH] [CKPT_PATH] [DEVICE_ID]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET=$(get_real_path $1)
CKPT=$(get_real_path $2)
DEVICEID=$3
export DEVICE_NUM=1
export DEVICE_ID=$DEVICEID
export RANK_ID=0
export RANK_SIZE=1
BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
echo "config file path $CONFIG_FILE"
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp ../*.yaml ./eval
cp -r ../src ./eval
cp -r ../scripts/*.sh ./eval
cd ./eval || exit
mkdir ckpt
echo "start training for device $DEVICE_ID"
env > env.log
python3 -u eval.py --config_path=$CONFIG_FILE --mindrecord_path=${DATASET} \
--enable_modelarts=False --ckpt_file=$CKPT --device_id=$DEVICEID > log.txt 2>&1 &
cd ..
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_standalone_train.sh MINDRECORD_PATH"
echo "for example: sh run_standalone_train.sh /data/mindrecord_path device_id"
echo "It is better to use absolute path."
echo "Please pay attention that the dataset should corresponds to dataset_name"
echo "=============================================================================================================="
if [[ $# -lt 2 ]]; then
echo "Usage: bash run_standalone_train.sh [MINDRECORD_PATH] [DEVICE_ID]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET=$(get_real_path $1)
DEVICEID=$2
export DEVICE_NUM=1
export DEVICE_ID=$DEVICEID
export RANK_ID=0
export RANK_SIZE=1
BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
echo "config file path $CONFIG_FILE"
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp ../*.yaml ./train
cp -r ../src ./train
cp -r ../scripts/*.sh ./train
cd ./train || exit
mkdir ckpt
echo "start training for device $DEVICE_ID"
env > env.log
python3 -u train.py --config_path=$CONFIG_FILE --mindrecord_path=${DATASET} \
--enable_modelarts=False --distribute=False --device_id=$DEVICEID > log.txt 2>&1 &
cd ..
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create mindrecord"""
import os.path
import pandas as pd
import numpy as np
from mindspore.mindrecord import FileWriter
from src.model_utils.config import config
from src.dataloader import preprocess_dataset
class DatasetIter:
def __init__(self, coor_features, affinity):
self.coor_features = coor_features
self.affinity = affinity
def __getitem__(self, index):
return self.coor_features[index], self.affinity[index]
def __len__(self):
return len(self.coor_features)
def create_mindrecord():
train_coords_features, y_train_size, train_no_rotationcoords_features, \
val_coords_features, no_batch_size = preprocess_dataset(config, config.data_path,
batch_rotation=list(range(config.rotations)),
batch_no_rotation=0, v_batch_rotation=0)
print("train size: ", y_train_size, flush=True)
print("Validation size: ", no_batch_size, flush=True)
train_rotation_path = os.path.join(config.mindrecord_path, 'train_rotation')
train_no_rotation_path = os.path.join(config.mindrecord_path, 'no_rotation')
val_path = os.path.join(config.mindrecord_path, 'val')
if not os.path.exists(train_rotation_path):
os.mkdir(train_rotation_path)
if not os.path.exists(train_no_rotation_path):
os.mkdir(train_no_rotation_path)
if not os.path.exists(val_path):
os.mkdir(val_path)
train_rot_writer = FileWriter(os.path.join(train_rotation_path, 'train_rotation_dataset.mindrecord'), shard_num=1)
rot_train_data_schema = {
"coords_features": {"type": "float32", "shape": [19, 21, 21, 21]},
"affinitys": {"type": "float32", "shape": [-1]}
}
train_rot_writer.add_schema(rot_train_data_schema, "pdbbind_rot")
data_iterator = DatasetIter(train_coords_features['coords_features'], train_coords_features['affinitys'])
train_rot_item = {'coords_features': [], 'affinitys': []}
for coor_feature, affine in data_iterator:
train_rot_item['coords_features'] = np.array(coor_feature, dtype=np.float32)
train_rot_item['affinitys'] = np.array(affine, dtype=np.float32)
train_rot_writer.write_raw_data([train_rot_item])
train_rot_writer.commit()
print("Rotation training mindrecord create finished!", flush=True)
train_norot_writer = FileWriter(os.path.join(train_no_rotation_path,
'train_norotation_dataset.mindrecord'), shard_num=1)
norot_train_data_schema = {
"coords_features": {"type": "float32", "shape": [19, 21, 21, 21]},
"affinitys": {"type": "float32", "shape": [-1]}
}
train_norot_writer.add_schema(norot_train_data_schema, "pdbbind_norot")
norot_data_iterator = DatasetIter(train_no_rotationcoords_features['coords_features'],
train_no_rotationcoords_features['affinitys'])
train_no_rot_item = {'coords_features': [], 'affinitys': []}
for coor_feature, affine in norot_data_iterator:
train_no_rot_item['coords_features'] = np.array(coor_feature, dtype=np.float32)
train_no_rot_item['affinitys'] = np.array(affine, dtype=np.float32)
train_norot_writer.write_raw_data([train_no_rot_item])
train_norot_writer.commit()
print("No rotation training mindrecord create finished!", flush=True)
val_writer = FileWriter(os.path.join(val_path, 'validation_dataset.mindrecord'), shard_num=1)
val_data_schema = {
"coords_features": {"type": "float32", "shape": [19, 21, 21, 21]},
"affinitys": {"type": "float32", "shape": [-1]}
}
val_writer.add_schema(val_data_schema, "pdbbind_val")
val_data_iterator = DatasetIter(val_coords_features['coords_features'],
val_coords_features['affinitys'])
val_rot_item = {'coords_features': [], 'affinitys': []}
for coor_feature, affine in val_data_iterator:
val_rot_item['coords_features'] = np.array(coor_feature, dtype=np.float32)
val_rot_item['affinitys'] = np.array(affine, dtype=np.float32)
val_writer.write_raw_data([val_rot_item])
val_writer.commit()
size_list = [{'dataset': 'train_size', "size": y_train_size},
{'dataset': 'val_size', "size": no_batch_size}]
results = pd.DataFrame(size_list, columns=['dataset', 'size'])
results.to_csv(os.path.join(config.mindrecord_path, 'ds_size.csv'), index=False)
print("Validation mindrecord create finished!", flush=True)
if __name__ == '__main__':
create_mindrecord()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""calcaulates atomic features for molecules"""
import pickle
from itertools import combinations
from math import ceil, sin, cos, sqrt, pi
import numpy as np
import pybel
class Featurizer():
"""Calcaulates atomic features for molecules. Features can encode atom type,
native pybel properties or any property defined with SMARTS patterns
Attributes
----------
FEATURE_NAMES: list of strings
Labels for features (in the same order as features)
NUM_ATOM_CLASSES: int
Number of atom codes
ATOM_CODES: dict
Dictionary mapping atomic numbers to codes
NAMED_PROPS: list of string
Names of atomic properties to retrieve from pybel.Atom object
CALLABLES: list of callables
Callables used to calculate custom atomic properties
SMARTS: list of SMARTS strings
SMARTS patterns defining additional atomic properties
"""
def __init__(self, atom_codes=None, atom_labels=None,
named_properties=None, save_molecule_codes=True,
custom_properties=None, smarts_properties=None,
smarts_labels=None):
"""Creates Featurizer with specified types of features. Elements of a
feature vector will be in a following order: atom type encoding
(defined by atom_codes), Pybel atomic properties (defined by
named_properties), molecule code (if present), custom atomic properties
(defined `custom_properties`), and additional properties defined with
SMARTS (defined with `smarts_properties`).
Parameters
----------
atom_codes: dict, optional
Dictionary mapping atomic numbers to codes. It will be used for
            one-hot encoding, therefore if n different types are used, codes
            should be from 0 to n-1. Multiple atoms can have the same code,
e.g. you can use {6: 0, 7: 1, 8: 1} to encode carbons with [1, 0]
and nitrogens and oxygens with [0, 1] vectors. If not provided,
default encoding is used.
atom_labels: list of strings, optional
Labels for atoms codes. It should have the same length as the
number of used codes, e.g. for `atom_codes={6: 0, 7: 1, 8: 1}` you
should provide something like ['C', 'O or N']. If not specified
labels 'atom0', 'atom1' etc are used. If `atom_codes` is not
specified this argument is ignored.
named_properties: list of strings, optional
Names of atomic properties to retrieve from pybel.Atom object. If
not specified ['hyb', 'heavyvalence', 'heterovalence',
'partialcharge'] is used.
save_molecule_codes: bool, optional (default True)
If set to True, there will be an additional feature to save
            molecule code. It is useful when saving molecular complex in a
single array.
custom_properties: list of callables, optional
Custom functions to calculate atomic properties. Each element of
this list should be a callable that takes pybel.Atom object and
returns a float. If callable has `__name__` property it is used as
feature label. Otherwise labels 'func<i>' etc are used, where i is
the index in `custom_properties` list.
smarts_properties: list of strings, optional
Additional atomic properties defined with SMARTS patterns. These
patterns should match a single atom. If not specified, default
patterns are used.
smarts_labels: list of strings, optional
Labels for properties defined with SMARTS. Should have the same
length as `smarts_properties`. If not specified labels 'smarts0',
'smarts1' etc are used. If `smarts_properties` is not specified
this argument is ignored.
"""
        # Remember names of all features in the correct order
self.FEATURE_NAMES = []
if atom_codes is not None:
if not isinstance(atom_codes, dict):
raise TypeError('Atom codes should be dict, got %s instead'
% type(atom_codes))
codes = set(atom_codes.values())
for i in range(len(codes)):
if i not in codes:
raise ValueError('Incorrect atom code %s' % i)
self.NUM_ATOM_CLASSES = len(codes)
self.ATOM_CODES = atom_codes
if atom_labels is not None:
if len(atom_labels) != self.NUM_ATOM_CLASSES:
raise ValueError('Incorrect number of atom labels: '
'%s instead of %s'
% (len(atom_labels), self.NUM_ATOM_CLASSES))
else:
atom_labels = ['atom%s' % i for i in range(self.NUM_ATOM_CLASSES)]
self.FEATURE_NAMES += atom_labels
else:
self.ATOM_CODES = {}
metals = ([3, 4, 11, 12, 13] + list(range(19, 32))
+ list(range(37, 51)) + list(range(55, 84))
+ list(range(87, 104)))
# List of tuples (atomic_num, class_name) with atom types to encode.
atom_classes = [
(5, 'B'),
(6, 'C'),
(7, 'N'),
(8, 'O'),
(15, 'P'),
(16, 'S'),
(34, 'Se'),
([9, 17, 35, 53], 'halogen'),
(metals, 'metal')
]
for code, (atom, name) in enumerate(atom_classes):
if isinstance(atom, list):
for a in atom:
self.ATOM_CODES[a] = code
else:
self.ATOM_CODES[atom] = code
self.FEATURE_NAMES.append(name)
self.NUM_ATOM_CLASSES = len(atom_classes)
if named_properties is not None:
if not isinstance(named_properties, (list, tuple, np.ndarray)):
raise TypeError('named_properties must be a list')
allowed_props = [prop for prop in dir(pybel.Atom)
if not prop.startswith('__')]
for prop_id, prop in enumerate(named_properties):
if prop not in allowed_props:
raise ValueError(
'named_properties must be in pybel.Atom attributes,'
                        ' %s was given at position %s' % (prop, prop_id)
)
self.NAMED_PROPS = named_properties
else:
self.NAMED_PROPS = ['hyb', 'heavyvalence', 'heterovalence',
'partialcharge']
self.FEATURE_NAMES += self.NAMED_PROPS
if not isinstance(save_molecule_codes, bool):
raise TypeError('save_molecule_codes should be bool, got %s '
'instead' % type(save_molecule_codes))
self.save_molecule_codes = save_molecule_codes
if save_molecule_codes:
# Remember if an atom belongs to the ligand or to the protein
self.FEATURE_NAMES.append('molcode')
self.CALLABLES = []
if custom_properties is not None:
for i, func in enumerate(custom_properties):
if not callable(func):
raise TypeError('custom_properties should be list of'
' callables, got %s instead' % type(func))
name = getattr(func, '__name__', '')
if name == '':
name = 'func%s' % i
self.CALLABLES.append(func)
self.FEATURE_NAMES.append(name)
if smarts_properties is None:
# SMARTS definition for other properties
self.SMARTS = [
'[#6+0!$(*~[#7,#8,F]),SH0+0v2,s+0,S^3,Cl+0,Br+0,I+0]',
'[a]',
'[!$([#1,#6,F,Cl,Br,I,o,s,nX3,#7v5,#15v5,#16v4,#16v6,*+1,*+2,*+3])]',
'[!$([#6,H0,-,-2,-3]),$([!H0;#7,#8,#9])]',
'[r]'
]
smarts_labels = ['hydrophobic', 'aromatic', 'acceptor', 'donor',
'ring']
elif not isinstance(smarts_properties, (list, tuple, np.ndarray)):
raise TypeError('smarts_properties must be a list')
else:
self.SMARTS = smarts_properties
if smarts_labels is not None:
if len(smarts_labels) != len(self.SMARTS):
raise ValueError('Incorrect number of SMARTS labels: %s'
' instead of %s'
% (len(smarts_labels), len(self.SMARTS)))
else:
smarts_labels = ['smarts%s' % i for i in range(len(self.SMARTS))]
# Compile patterns
self.compile_smarts()
self.FEATURE_NAMES += smarts_labels
def compile_smarts(self):
self.__PATTERNS = []
for smarts in self.SMARTS:
self.__PATTERNS.append(pybel.Smarts(smarts))
def encode_num(self, atomic_num):
"""Encode atom type with a binary vector. If atom type is not included in
the `atom_classes`, its encoding is an all-zeros vector.
Parameters
----------
atomic_num: int
Atomic number
Returns
-------
encoding: np.ndarray
Binary vector encoding atom type (one-hot or null).
"""
if not isinstance(atomic_num, int):
raise TypeError('Atomic number must be int, %s was given'
% type(atomic_num))
encoding = np.zeros(self.NUM_ATOM_CLASSES)
try:
encoding[self.ATOM_CODES[atomic_num]] = 1.0
except KeyError:  # unknown atomic numbers get an all-zeros encoding
pass
return encoding
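# Worked example of the default encoding (a sketch; assumes the default atom
# classes listed above, i.e. B, C, N, O, P, S, Se, halogen, metal):
#   Featurizer().encode_num(6)   -> [0, 1, 0, 0, 0, 0, 0, 0, 0]   (carbon)
#   Featurizer().encode_num(35)  -> [0, 0, 0, 0, 0, 0, 0, 1, 0]   (bromine -> halogen)
#   Featurizer().encode_num(1)   -> all zeros (hydrogen is not part of any class)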
def find_smarts(self, molecule):
"""Find atoms that match SMARTS patterns.
Parameters
----------
molecule: pybel.Molecule
Returns
-------
features: np.ndarray
NxM binary array, where N is the number of atoms in the `molecule`
and M is the number of patterns. `features[i, j]` == 1.0 if i'th
atom has j'th property
"""
if not isinstance(molecule, pybel.Molecule):
raise TypeError('molecule must be pybel.Molecule object, %s was given'
% type(molecule))
features = np.zeros((len(molecule.atoms), len(self.__PATTERNS)))
for (pattern_id, pattern) in enumerate(self.__PATTERNS):
atoms_with_prop = np.array(list(*zip(*pattern.findall(molecule))),
dtype=int) - 1
features[atoms_with_prop, pattern_id] = 1.0
return features
def get_features(self, molecule, molcode=None):
"""Get coordinates and features for all heavy atoms in the molecule.
Parameters
----------
molecule: pybel.Molecule
molcode: float, optional
Molecule type. You can use it to encode whether an atom belongs to
the ligand (1.0) or to the protein (-1.0) etc.
Returns
-------
coords: np.ndarray, shape = (N, 3)
Coordinates of all heavy atoms in the `molecule`.
features: np.ndarray, shape = (N, F)
Features of all heavy atoms in the `molecule`: atom type
(one-hot encoding), pybel.Atom attributes, type of a molecule
(e.g. protein/ligand distinction), and other properties defined with
SMARTS patterns
"""
if not isinstance(molecule, pybel.Molecule):
raise TypeError('molecule must be pybel.Molecule object,'
' %s was given' % type(molecule))
if molcode is None:
if self.save_molecule_codes is True:
raise ValueError('save_molecule_codes is set to True,'
' you must specify code for the molecule')
elif not isinstance(molcode, (float, int)):
raise TypeError('molcode must be float, %s was given'
% type(molcode))
coords = []
features = []
heavy_atoms = []
for i, atom in enumerate(molecule):
# ignore hydrogens and dummy atoms (they have atomicnum set to 0)
if atom.atomicnum > 1:
heavy_atoms.append(i)
coords.append(atom.coords)
features.append(np.concatenate((
self.encode_num(atom.atomicnum),
[atom.__getattribute__(prop) for prop in self.NAMED_PROPS],
[func(atom) for func in self.CALLABLES],
)))
coords = np.array(coords, dtype=np.float32)
features = np.array(features, dtype=np.float32)
if self.save_molecule_codes:
features = np.hstack((features,
molcode * np.ones((len(features), 1))))
features = np.hstack([features,
self.find_smarts(molecule)[heavy_atoms]])
if np.isnan(features).any():
raise RuntimeError('Got NaN when calculating features')
return coords, features
def to_pickle(self, fname='featurizer.pkl'):
"""Save featurizer in a given file. Featurizer can be restored with
`from_pickle` method.
Parameters
----------
fname: str, optional
Path to file in which featurizer will be saved
"""
# patterns can't be pickled, we need to temporarily remove them
patterns = self.__PATTERNS[:]
del self.__PATTERNS
try:
with open(fname, 'wb') as f:
pickle.dump(self, f)
finally:
self.__PATTERNS = patterns[:]
@staticmethod
def from_pickle(fname):
"""Load pickled featurizer from a given file
Parameters
----------
fname: str, optional
Path to file with saved featurizer
Returns
-------
featurizer: Featurizer object
Loaded featurizer
"""
with open(fname, 'rb') as f:
featurizer = pickle.load(f)
featurizer.compile_smarts()
return featurizer
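# Minimal usage sketch (assumes openbabel's pybel is installed; 'ligand.mol2'
# is a hypothetical input file, not shipped with this repository):
#
#   featurizer = Featurizer()
#   mol = next(pybel.readfile('mol2', 'ligand.mol2'))
#   coords, features = featurizer.get_features(mol, molcode=1.0)
#   # coords.shape == (N, 3), features.shape == (N, len(featurizer.FEATURE_NAMES))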
def rotation_matrix(in_axis, in_theta):
"""Counterclockwise rotation about a given axis by theta radians"""
if not isinstance(in_axis, (np.ndarray, list, tuple)):
raise TypeError('axis must be an array of floats of shape (3,)')
try:
in_axis = np.asarray(in_axis, dtype=float)
except ValueError:
raise ValueError('axis must be an array of floats of shape (3,)')
if in_axis.shape != (3,):
raise ValueError('axis must be an array of floats of shape (3,)')
if not isinstance(in_theta, (float, int)):
raise TypeError('theta must be a float')
in_axis = in_axis / sqrt(np.dot(in_axis, in_axis))
a = cos(in_theta / 2.0)
b, c, d = -in_axis * sin(in_theta / 2.0)
aa, bb, cc, dd = a * a, b * b, c * c, d * d
bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
[2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
[2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
# Create matrices for all rotations that map a cubic box onto itself
ROTATIONS = [rotation_matrix([1, 1, 1], 0)]
# about X, Y and Z - 9 rotations
for a1 in range(3):
for t in range(1, 4):
axis = np.zeros(3)
axis[a1] = 1
theta = t * pi / 2.0
ROTATIONS.append(rotation_matrix(axis, theta))
# about each face diagonal - 6 rotations
for (a1, a2) in combinations(range(3), 2):
axis = np.zeros(3)
axis[[a1, a2]] = 1.0
theta = pi
ROTATIONS.append(rotation_matrix(axis, theta))
axis[a2] = -1.0
ROTATIONS.append(rotation_matrix(axis, theta))
# about each space diagonal - 8 rotations
for t in [1, 2]:
theta = t * 2 * pi / 3
axis = np.ones(3)
ROTATIONS.append(rotation_matrix(axis, theta))
for a1 in range(3):
axis = np.ones(3)
axis[a1] = -1
ROTATIONS.append(rotation_matrix(axis, theta))
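# Altogether ROTATIONS holds the 24 rotational symmetries of a cube:
# 1 identity + 9 axis rotations + 6 face-diagonal rotations + 8 space-diagonal rotations.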
def rotate(coords, rotation):
"""Rotate coordinates by a given rotation
Parameters
----------
coords: array-like, shape (N, 3)
Array with coordinates of each atom.
rotation: int or array-like, shape (3, 3)
Rotation to perform. You can either select predefined rotation by
giving its index or specify rotation matrix.
Returns
-------
coords: np.ndarray, shape = (N, 3)
Rotated coordinates.
"""
global ROTATIONS
if not isinstance(coords, (np.ndarray, list, tuple)):
raise TypeError('coords must be an array of floats of shape (N, 3)')
try:
coords = np.asarray(coords, dtype=float)
except ValueError:
raise ValueError('coords must be an array of floats of shape (N, 3)')
shape = coords.shape
if len(shape) != 2 or shape[1] != 3:
raise ValueError('coords must be an array of floats of shape (N, 3)')
if isinstance(rotation, int):
if 0 <= rotation < len(ROTATIONS):
out = np.dot(coords, ROTATIONS[rotation])
else:
raise ValueError('Invalid rotation number %s!' % rotation)
elif isinstance(rotation, np.ndarray) and rotation.shape == (3, 3):
out = np.dot(coords, rotation)
else:
raise ValueError('Invalid rotation %s!' % rotation)
return out
def make_grid(coords, features, grid_resolution=1.0, max_dist=10.0):
"""Convert atom coordinates and features represented as 2D arrays into a
fixed-sized 3D box.
Parameters
----------
coords, features: array-likes, shape (N, 3) and (N, F)
Arrays with coordinates and features for each atom.
grid_resolution: float, optional
Resolution of a grid (in Angstroms).
max_dist: float, optional
Maximum distance between atom and box center. Resulting box has size of
2*`max_dist`+1 Angstroms and atoms that are too far away are not
included.
Returns
-------
grid: np.ndarray, shape = (1, M, M, M, F)
5D array with atom properties distributed in 3D space. M is equal to
2 * `max_dist` / `grid_resolution` + 1
"""
try:
coords = np.asarray(coords, dtype=float)
except ValueError:
raise ValueError('coords must be an array of floats of shape (N, 3)')
c_shape = coords.shape
if len(c_shape) != 2 or c_shape[1] != 3:
raise ValueError('coords must be an array of floats of shape (N, 3)')
N = len(coords)
try:
features = np.asarray(features, dtype=float)
except ValueError:
raise ValueError('features must be an array of floats of shape (N, F)')
f_shape = features.shape
if len(f_shape) != 2 or f_shape[0] != N:
raise ValueError('features must be an array of floats of shape (N, F)')
if not isinstance(grid_resolution, (float, int)):
raise TypeError('grid_resolution must be float')
if grid_resolution <= 0:
raise ValueError('grid_resolution must be positive')
if not isinstance(max_dist, (float, int)):
raise TypeError('max_dist must be float')
if max_dist <= 0:
raise ValueError('max_dist must be positive')
num_features = f_shape[1]
max_dist = float(max_dist)
grid_resolution = float(grid_resolution)
box_size = ceil(2 * max_dist / grid_resolution + 1)
# move all atoms to the nearest grid point
grid_coords = (coords + max_dist) / grid_resolution
grid_coords = grid_coords.round().astype(int)
# remove atoms outside the box
in_box = ((grid_coords >= 0) & (grid_coords < box_size)).all(axis=1)
grid = np.zeros((1, box_size, box_size, box_size, num_features),
dtype=np.float32)
for (x, y, z), f in zip(grid_coords[in_box], features[in_box]):
grid[0, x, y, z] += f
return grid
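# Worked example (a sketch; `coords` and `features` as returned by
# Featurizer.get_features). With the defaults max_dist=10 and
# grid_resolution=1 Angstrom the box has 2*10+1 = 21 points per axis:
#
#   grid = make_grid(rotate(coords, 0), features)   # rotation 0 is the identity
#   # grid.shape == (1, 21, 21, 21, features.shape[1])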
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset loader"""
import os
import numpy as np
import h5py
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
from mindspore.common import dtype as mstype
from src.data import Featurizer, make_grid, rotate
def get_batch(configs, dataset_name, indices, coords, features, columns, std, affin, rotations=(0,)):
"""Build grid tensors for `indices`, applying every rotation index in the iterable `rotations`."""
train_x_final = {'coords_features': [], 'affinitys': []}
for rotation in rotations:
for i, idx in enumerate(indices):
x = []
coords_idx = rotate(coords[dataset_name][idx], rotation)
features_idx = features[dataset_name][idx]
x.append(make_grid(coords_idx, features_idx,
grid_resolution=configs.grid_spacing,
max_dist=configs.max_dist))
x = np.vstack(x)
x[..., columns['partialcharge']] /= std
train_x_final['coords_features'].append(np.transpose(np.squeeze(x), axes=(3, 0, 1, 2)))
train_x_final['affinitys'].append(affin[i])
return train_x_final
def extract_features(configs, dataset_name, idx, coords, features, columns, std, rotation):
"""Rotate one complex, voxelise it with `make_grid` and scale partial charges by `std`."""
x = []
coords_idx = rotate(coords[dataset_name][idx], rotation)
features_idx = features[dataset_name][idx]
x.append(make_grid(coords_idx, features_idx,
grid_resolution=configs.grid_spacing,
max_dist=configs.max_dist))
x = np.vstack(x)
x[..., columns['partialcharge']] /= std
x = np.transpose(np.squeeze(x), axes=(3, 0, 1, 2))
return x
def get_batchs(configs, dataset_name, indices, coords, features, columns, std, affin, rotations=0):
"""Build grid tensors for `indices`; `rotations` is a single rotation index (int) or an iterable of indices."""
train_x_final = {'coords_features': [], 'affinitys': []}
if isinstance(rotations, int):
for i, idx in enumerate(indices):
x = extract_features(configs, dataset_name, idx, coords, features, columns, std, rotations)
train_x_final['coords_features'].append(x)
train_x_final['affinitys'].append(affin[i])
else:
for rotation in rotations:
for i, idx in enumerate(indices):
x = extract_features(configs, dataset_name, idx, coords, features, columns, std, rotation)
train_x_final['coords_features'].append(x)
train_x_final['affinitys'].append(affin[i])
return train_x_final
def preprocess_dataset(configs, paths, batch_rotation, batch_no_rotation, v_batch_rotation):
"""dataset preprocess"""
datasets_stage = ['validation', 'training', 'test']
ids = {}
affinity = {}
coords = {}
features = {}
featurizer = Featurizer()
print("atomic properties: ", featurizer.FEATURE_NAMES)
columns = {name: i for i, name in enumerate(featurizer.FEATURE_NAMES)}
for dictionary in [ids, affinity, coords, features]:
for datasets_name in datasets_stage:
dictionary[datasets_name] = []
paths = os.path.abspath(paths)
for dataset_name in datasets_stage:
dataset_path = os.path.join(paths, dataset_name + '_set.hdf')
with h5py.File(dataset_path, 'r') as f:
for pdb_id in f:
dataset = f[pdb_id]
coords[dataset_name].append(dataset[:, :3])
features[dataset_name].append(dataset[:, 3:])
affinity[dataset_name].append(dataset.attrs['affinity'])
ids[dataset_name].append(pdb_id)
ids[dataset_name] = np.array(ids[dataset_name])
affinity[dataset_name] = np.reshape(affinity[dataset_name], (-1, 1))
charges = []
for feature_data in features['training']:
charges.append(feature_data[..., columns['partialcharge']])
charges = np.concatenate([c.flatten() for c in charges])
charges_mean = charges.mean()
charges_std = charges.std()
print("charges mean=%s, std=%s" % (charges_mean, charges_std))
print("Using charges std as scaling factor")
# Best error we can get without any training (MSE from training set mean):
t_baseline = ((affinity['training'] - affinity['training'].mean()) ** 2.0).mean()
v_baseline = ((affinity['validation'] - affinity['training'].mean()) ** 2.0).mean()
print('baseline mse: training=%s, validation=%s' % (t_baseline, v_baseline))
ds_sizes = {dataset: len(affinity[dataset]) for dataset in datasets_stage}
# val set
val_y = affinity['validation']
no_batch_size = ds_sizes['validation']
ds_sizes_range = list(range(no_batch_size))
val_coords_features = get_batchs(configs, 'validation', ds_sizes_range,
coords, features, columns, charges_std, val_y, v_batch_rotation)
# train set with rotation
train_y = affinity['training']
y_train_size = ds_sizes['training']
train_sizes_range = list(range(y_train_size))
train_coords_features = get_batchs(configs, 'training', train_sizes_range, coords,
features, columns, charges_std, train_y, batch_rotation)
# train set without rotation
train_no_rotationcoords_features = get_batchs(configs, 'training', train_sizes_range, coords,
features, columns, charges_std, train_y, batch_no_rotation)
return train_coords_features, y_train_size, train_no_rotationcoords_features, val_coords_features, no_batch_size
class DatasetIter:
"""dataset iterator"""
def __init__(self, coor_features, affinity):
self.coor_features = coor_features
self.affinity = affinity
def __getitem__(self, index):
return self.coor_features[index], self.affinity[index]
def __len__(self):
return len(self.coor_features)
def minddataset_loader(configs, mindfile, no_batch_size):
"""rotation and without rotation dataset loader"""
rank_size, rank_id = _get_rank_info()
no_rot_weight = configs.batch_size / no_batch_size
train_loader = ds.MindDataset(mindfile, columns_list=["coords_features", "affinitys"],
num_parallel_workers=8, num_shards=rank_size, shard_id=rank_id)
type_cast_op = C.TypeCast(mstype.float32)
train_loader = train_loader.map(input_columns='coords_features', operations=type_cast_op)
train_loader = train_loader.map(input_columns='affinitys', operations=type_cast_op)
train_loader = train_loader.batch(batch_size=configs.batch_size, drop_remainder=True)
return train_loader, no_rot_weight
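# Usage sketch ('train.mindrecord' and `no_batch_size` are illustrative; the
# MindRecord file is produced by the preprocessing step, not by this module):
#
#   loader, weight = minddataset_loader(configs, 'train.mindrecord', no_batch_size)
#   for batch in loader.create_dict_iterator():
#       grids, affinities = batch['coords_features'], batch['affinitys']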
def minddataset_loader_val(configs, mindfile, no_batch_size):
"""validation dataset loader"""
no_rot_weight = configs.batch_size / no_batch_size
train_loader = ds.MindDataset(mindfile, columns_list=["coords_features", "affinitys"],
num_parallel_workers=8)
type_cast_op = C.TypeCast(mstype.float32)
train_loader = train_loader.map(input_columns='coords_features', operations=type_cast_op)
train_loader = train_loader.map(input_columns='affinitys', operations=type_cast_op)
train_loader = train_loader.batch(batch_size=configs.batch_size)
return train_loader, no_rot_weight
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
from mindspore.communication.management import get_rank, get_group_size
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = rank_id = None
return rank_size, rank_id
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom Logger."""
import os
import sys
import logging
from datetime import datetime
class LOGGER(logging.Logger):
"""
Logger.
Args:
logger_name: String. Logger name.
rank: Integer. Rank id.
"""
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
self.rank = rank
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
"""Setup logging file."""
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*'*70 + '\n')*line_width
important_msg += ('*'*line_width + '\n')*2
important_msg += '*'*line_width + ' '*8 + msg + '\n'
important_msg += ('*'*line_width + '\n')*2
important_msg += ('*'*70 + '\n')*line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
"""Get Logger."""
logger = LOGGER('Pafnucy Network', rank)
logger.setup_logging_file(path, rank)
return logger
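# Usage sketch (the log directory is illustrative):
#   logger = get_logger('./outputs/log', rank=0)
#   logger.save_args(config)
#   logger.info('training started')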
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
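# Example (sketch): nested dictionaries become nested attribute access.
#   cfg = Config({'lr': 0.001, 'optim': {'name': 'adam'}})
#   cfg.lr          -> 0.001
#   cfg.optim.name  -> 'adam'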
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
pprint(final_config)
print("Please check the above information for the configurations", flush=True)
return Config(final_config)
config = get_config()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from .config import config
if config.enable_modelarts:
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from .config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote OBS to a local directory when the first argument is a remote URL,
or upload a local directory to remote OBS when the direction is reversed.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
if config.enable_profiling:
profiler = Profiler()
run_func(*args, **kwargs)
if config.enable_profiling:
profiler.analyse()
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper
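# Usage sketch (hypothetical entry point): decorating the training function
# synchronises data with OBS before the run and uploads outputs afterwards
# when `enable_modelarts` is set in the configuration.
#
#   @moxing_wrapper()
#   def run_train():
#       ...  # build dataset, model and call model.train()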