Baseline model of BoostCamp2021 P-Stage DST

A Korean implementation of TRADE, an open-vocabulary DST model. (See Lectures 5 and 6.)

  • Instead of the original GloVe and character embeddings, the token_embeddings of monologg/koelectra-base-v3-discriminator are used as pretrained subword embeddings.
  • To save memory, a projection layer maps the token embedding (768) to the hidden dimension (400).
  • Parallel decoding is implemented to speed up training.
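
As a rough illustration of why the 768 => 400 projection can save memory, the sketch below counts parameters for a GRU with and without it. It assumes a single-layer, unidirectional GRU with PyTorch's parameter layout, a bias-free projection, and a GRU whose input and hidden sizes both equal the projected dimension; none of these details are stated in this README, and the helper names are made up for the example.

```python
def gru_params(input_size: int, hidden_size: int) -> int:
    """Parameter count of a single-layer, unidirectional GRU (PyTorch layout).

    Three gates, each with an input weight matrix, a hidden weight matrix,
    and two bias vectors.
    """
    return 3 * hidden_size * (input_size + hidden_size + 2)


def linear_params(in_features: int, out_features: int, bias: bool = False) -> int:
    """Parameter count of a dense projection layer."""
    return in_features * out_features + (out_features if bias else 0)


EMBED_DIM = 768   # token embedding size stated in this README
HIDDEN_DIM = 400  # hidden dimension stated in this README

# Assumption: without the projection, the GRU would run at width 768.
without_projection = gru_params(EMBED_DIM, EMBED_DIM)

# Assumption: with the projection, GRU input and hidden sizes are both 400.
with_projection = (
    linear_params(EMBED_DIM, HIDDEN_DIM) + gru_params(HIDDEN_DIM, HIDDEN_DIM)
)

print(f"GRU at 768:              {without_projection:,} parameters")
print(f"projection + GRU at 400: {with_projection:,} parameters")
```

The embedding table itself keeps its 768-wide rows either way; under these assumptions the saving comes from the narrower recurrent layers that follow it.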

1. Install the required libraries

pip install -r requirements.txt

2. Train the model

SM_CHANNEL_TRAIN=data/train_dataset SM_MODEL_DIR=[model saving dir] python train.py
The trained model is saved after each epoch as SM_MODEL_DIR/model-{epoch}.bin.
The configuration files needed for inference are saved in the same directory.
The best checkpoint path is printed at the end of training.
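
The naming scheme above can be sketched in Python to locate checkpoints after a run. This is a minimal sketch rather than code from train.py: the helper names and the "model" fallback directory are hypothetical, and only the model-{epoch}.bin pattern comes from this README.

```python
import os
from pathlib import Path
from typing import List


def checkpoint_path(model_dir: str, epoch: int) -> Path:
    """Path of the checkpoint for one epoch: SM_MODEL_DIR/model-{epoch}.bin."""
    return Path(model_dir) / f"model-{epoch}.bin"


def list_checkpoints(model_dir: str) -> List[Path]:
    """All model-{epoch}.bin files in model_dir, sorted by epoch number."""
    found = []
    for path in Path(model_dir).glob("model-*.bin"):
        epoch = path.stem[len("model-"):]
        if epoch.isdigit():  # skip files such as model-best.bin
            found.append((int(epoch), path))
    found.sort(key=lambda item: item[0])
    return [path for _, path in found]


# "model" is a hypothetical fallback used only when SM_MODEL_DIR is unset.
model_dir = os.environ.get("SM_MODEL_DIR", "model")
print(checkpoint_path(model_dir, 3))
```

Sorting by the parsed epoch number rather than by file name keeps model-10.bin after model-2.bin.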

3. Run inference

SM_CHANNEL_EVAL=data/eval_dataset/public SM_CHANNEL_MODEL=[Model Checkpoint Path] SM_OUTPUT_DATA_DIR=[Output path] python inference.py
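
When inference is started from another Python script instead of the shell, the same three variables can be passed through the process environment. The sketch below is one possible wrapper; inference_env and run_inference are hypothetical names, and creating the output directory up front is a convenience added here, not something this README says inference.py requires.

```python
import os
import subprocess
import sys
from typing import Dict, Optional


def inference_env(eval_dir: str, checkpoint: str, output_dir: str,
                  base_env: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Copy the environment and add the three variables from the command above."""
    env = dict(os.environ if base_env is None else base_env)
    env.update({
        "SM_CHANNEL_EVAL": eval_dir,
        "SM_CHANNEL_MODEL": checkpoint,
        "SM_OUTPUT_DATA_DIR": output_dir,
    })
    return env


def run_inference(eval_dir: str, checkpoint: str, output_dir: str) -> None:
    """Run inference.py from the current directory; raises if it exits non-zero."""
    os.makedirs(output_dir, exist_ok=True)  # convenience, see the note above
    subprocess.run(
        [sys.executable, "inference.py"],
        env=inference_env(eval_dir, checkpoint, output_dir),
        check=True,
    )
```

A call would look like run_inference("data/eval_dataset/public", checkpoint, output_dir), where checkpoint and output_dir stand for the [Model Checkpoint Path] and [Output path] placeholders in the command.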

4. Submit

Submit the predictions.json file that inference.py wrote to SM_OUTPUT_DATA_DIR in step 3.
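
Before submitting, it may be worth checking that the file parses. The sketch below assumes predictions.json is a JSON object whose values are lists of strings; that layout is an assumption, not something this README specifies, and check_predictions is a hypothetical helper.

```python
import json
from pathlib import Path


def check_predictions(output_dir: str) -> int:
    """Load predictions.json from output_dir and return the number of entries.

    Assumed layout (not confirmed by this README): a JSON object whose
    values are lists of strings.
    """
    path = Path(output_dir) / "predictions.json"
    with path.open(encoding="utf-8") as f:
        predictions = json.load(f)
    if not isinstance(predictions, dict):
        raise ValueError(
            f"{path}: expected a JSON object, got {type(predictions).__name__}"
        )
    for key, value in predictions.items():
        if not (isinstance(value, list) and all(isinstance(v, str) for v in value)):
            raise ValueError(f"{path}: entry {key!r} is not a list of strings")
    return len(predictions)
```

If the real layout differs, only the two isinstance checks need to change.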

Setting up wandb

  1. Run train.py and, at the prompt shown below, choose option 2 (use an existing W&B account).

image

  2. Open the link for obtaining an API key (second line in the screenshot) and log in with the shared account.

image

  3. Copy the key and paste it into the terminal, as shown below.

    • The run then proceeds as follows:

    image

  4. On the wandb website, check that the project has been created, as shown below.

image

image
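
For runs where nobody is at the terminal to answer the prompt, wandb also reads the API key from the WANDB_API_KEY environment variable. The sketch below loads the key from a local file and exports it before wandb is initialized; load_wandb_key and the key file are hypothetical and not part of this repository.

```python
import os
from pathlib import Path


def load_wandb_key(key_file: str) -> None:
    """Export the API key stored in key_file as WANDB_API_KEY."""
    key = Path(key_file).read_text(encoding="utf-8").strip()
    if not key:
        raise ValueError(f"{key_file} is empty")
    os.environ["WANDB_API_KEY"] = key
```

Calling load_wandb_key("wandb_key.txt") near the top of train.py would be one way to use it; the key file should stay out of version control.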