-
Notifications
You must be signed in to change notification settings - Fork 154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Example] Add TGCN Model for Traffic Forecasting #972
base: develop
Are you sure you want to change the base?
Conversation
* add TGCN docs * add TGCN model * add TGCN example * add PEMSD4 & PEMSD8 dataset
Thanks for your contribution! |
docs/zh/examples/tgcn.md
Outdated
# Train | ||
python PaddleScience/examples/tgcn/run.py data_name=PEMSD8 | ||
# python PaddleScience/examples/tgcn/run.py data_name=PEMSD4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
案例的默认执行路径在对应的案例文件夹下,而不是PaddleScience同级的目录下
# Train | |
python PaddleScience/examples/tgcn/run.py data_name=PEMSD8 | |
# python PaddleScience/examples/tgcn/run.py data_name=PEMSD4 | |
python run.py data_name=PEMSD8 | |
# python run.py data_name=PEMSD4 |
docs/zh/examples/tgcn.md
Outdated
# Eval | ||
python PaddleScience/examples/tgcn/run.py data_name=PEMSD8 mode=eval | ||
# python PaddleScience/examples/tgcn/run.py data_name=PEMSD4 mode=eval |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
案例的默认执行路径在对应的案例文件夹下,而不是PaddleScience同级的目录下
# Eval | |
python PaddleScience/examples/tgcn/run.py data_name=PEMSD8 mode=eval | |
# python PaddleScience/examples/tgcn/run.py data_name=PEMSD4 mode=eval | |
python run.py data_name=PEMSD8 mode=eval | |
# python run.py data_name=PEMSD4 mode=eval |
docs/zh/examples/tgcn.md
Outdated
开始训练、评估前,请下载数据集:[PEMSD4 & PEMSD8](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tgcn/tgcn_data.zip)。将解压后的数据集文件夹与 `PaddleScience` 文件夹放置于同一目录下。 | ||
|
||
开始评估前,请下载或训练生成预训练模型:[PEMSD4](https://paddle-org.bj.bcebos.com/paddlescience/models/tgcn/PEMSD4_pretrained_model.pdparams) & [PEMSD8](https://paddle-org.bj.bcebos.com/paddlescience/models/tgcn/PEMSD8_pretrained_model.pdparams)。将预训练模型文件与 `PaddleScience` 文件夹放置于同一目录下。 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docs/zh/examples/tgcn.md
Outdated
下表展示了 TGCN 在 PEMSD4 和 PEMSD8 两个数据集上的评估结果。 | ||
|
||
| 数据集 | MAE | RMSE | | ||
| :----- | :---- | :---- | | ||
| PEMSD4 | 21.48 | 34.06 | | ||
| PEMSD8 | 15.57 | 24.52 | | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
examples/tgcn/conf/run.yaml
Outdated
hydra: | ||
run: | ||
# dynamic output directory according to running time and override name | ||
dir: __exp__/${data_name}/${now:%Y_%m_%d_%H_%M_%S} | ||
job: | ||
name: ${mode} # name of logfile | ||
chdir: false # keep current working directory unchanged | ||
config: | ||
override_dirname: | ||
exclude_keys: | ||
- mode | ||
- output_dir | ||
- log_freq | ||
sweep: | ||
# output directory for multirun | ||
dir: ${hydra.run.dir} | ||
subdir: ./ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hydra: | |
run: | |
# dynamic output directory according to running time and override name | |
dir: __exp__/${data_name}/${now:%Y_%m_%d_%H_%M_%S} | |
job: | |
name: ${mode} # name of logfile | |
chdir: false # keep current working directory unchanged | |
config: | |
override_dirname: | |
exclude_keys: | |
- mode | |
- output_dir | |
- log_freq | |
sweep: | |
# output directory for multirun | |
dir: ${hydra.run.dir} | |
subdir: ./ | |
defaults: | |
- ppsci_default | |
- TRAIN: train_default | |
- TRAIN/ema: ema_default | |
- TRAIN/swa: swa_default | |
- EVAL: eval_default | |
- INFER: infer_default | |
- _self_ | |
hydra: | |
run: | |
# dynamic output directory according to running time and override name | |
dir: outputs_tgcn/${now:%Y-%m-%d}/${now:%H-%M-%S} | |
job: | |
name: ${mode} # name of logfile | |
chdir: false # keep current working directory unchanged | |
callbacks: | |
init_callback: | |
_target_: ppsci.utils.callbacks.InitCallback | |
sweep: | |
# output directory for multirun | |
dir: ${hydra.run.dir} | |
subdir: ./ |
examples/tgcn/run.py
Outdated
'input_keys': cfg.MODEL.afno.input_keys, | ||
'label_keys': cfg.MODEL.afno.label_keys, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
全局搜索修改一下:cfg.MODEL.afno.
--> cfg.MODEL.
examples/tgcn/run.py
Outdated
# set random seed for reproducibility | ||
ppsci.utils.misc.set_random_seed(cfg.seed) | ||
|
||
# initialize logger | ||
logger.init_logger('ppsci', os.path.join(cfg.output_dir, 'test.log'), 'info') | ||
logger.message(cfg) | ||
|
||
# set eval dataloader config |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# set random seed for reproducibility | |
ppsci.utils.misc.set_random_seed(cfg.seed) | |
# initialize logger | |
logger.init_logger('ppsci', os.path.join(cfg.output_dir, 'test.log'), 'info') | |
logger.message(cfg) | |
# set eval dataloader config | |
# set eval dataloader config |
examples/tgcn/run.py
Outdated
# set random seed for reproducibility | ||
ppsci.utils.misc.set_random_seed(cfg.seed) | ||
|
||
# initialize logger | ||
logger.init_logger('ppsci', os.path.join(cfg.output_dir, 'train.log'), 'info') | ||
logger.message(cfg) | ||
|
||
# set train dataloader config |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# set random seed for reproducibility | |
ppsci.utils.misc.set_random_seed(cfg.seed) | |
# initialize logger | |
logger.init_logger('ppsci', os.path.join(cfg.output_dir, 'train.log'), 'info') | |
logger.message(cfg) | |
# set train dataloader config | |
# set train dataloader config |
ppsci/arch/tgcn.py
Outdated
self.edge_index = pp.to_tensor(data=edge_index, place=cfg.device) | ||
self.edge_attr = pp.to_tensor(data=edge_attr, place=cfg.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两个属性好像并没有被用到?
ppsci/arch/tgcn.py
Outdated
|
||
self.edge_index = pp.to_tensor(data=edge_index, place=cfg.device) | ||
self.edge_attr = pp.to_tensor(data=edge_attr, place=cfg.device) | ||
self.adj = pp.to_tensor(data=adj, place=cfg.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@casia-rxwang 签署一下CLA协议: |
@casia-rxwang 修改代码的时候顺便合并一下最新develop分支的代码,并解决一下冲突 |
1129e3d
to
527d6a0
Compare
PR types
Others
PR changes
Others
Describe