Open-vocab based DST model์ธ TRADE์ ํ๊ตญ์ด ๊ตฌํ์ฒด์
๋๋ค. (5๊ฐ, 6๊ฐ ๋ด์ฉ ์ฐธ๊ณ )
- ๊ธฐ์กด์ GloVe, Char Embedding ๋์
monologg/koelectra-base-v3-discriminator
์token_embeddings
์pretrained Subword Embedding์ผ๋ก ์ฌ์ฉํฉ๋๋ค. - ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์๋ผ๊ธฐ ์ํด Token Embedding (768) => Hidden Dimension (400)์ผ๋ก์ Projection layer๊ฐ ๋ค์ด ์์ต๋๋ค.
- ๋น ๋ฅธ ํ์ต์ ์ํด
Parallel Decoding
์ด ๊ตฌํ๋์ด ์์ต๋๋ค.
pip install -r requirements.txt
SM_CHANNEL_TRAIN=data/train_dataset SM_MODEL_DIR=[model saving dir] python train.py
ํ์ต๋ ๋ชจ๋ธ์ epoch ๋ณ๋ก SM_MODEL_DIR/model-{epoch}.bin
์ผ๋ก ์ ์ฅ๋ฉ๋๋ค.
์ถ๋ก ์ ํ์ํ ๋ถ๊ฐ ์ ๋ณด์ธ configuration๋ค๋ ๊ฐ์ ๊ฒฝ๋ก์ ์ ์ฅ๋ฉ๋๋ค.
Best Checkpoint Path๊ฐ ํ์ต ๋ง์ง๋ง์ ํ๊ธฐ๋ฉ๋๋ค.
SM_CHANNEL_EVAL=data/eval_dataset/public SM_CHANNEL_MODEL=[Model Checkpoint Path] SM_OUTPUT_DATA_DIR=[Output path] python inference.py
3๋ฒ ์คํ
inference.py
์์ SM_OUTPUT_DATA_DIR
์ ์ ์ฅ๋ predictions.json
์ ์ ์ถํฉ๋๋ค.
- train.pyํ์ผ์ ์ํํ๋ฉด ๋ค์๊ณผ ๊ฐ์ ํ๋ฉด์์ 2๋ฒ ์ ํ
- API key๋ฅผ ๋ฐ์ ์ ์๋ ๋งํฌ๋ก ๋ค์ด๊ฐ (๊ทธ๋ฆผ 2๋ฒ์งธ ์ค) ๊ณต์ ๊ณ์ ์ผ๋ก ๋ก๊ทธ์ธ