[이론] LLM GPU 학습 병렬처리 (DP, DDP, FSDP)

Topic. 출처 : 원문

1. DP (Data Parallel)

각 GPU 에 서로 다른 데이터 분산처리방식

import torch.nn as nn
model = nn.DataParallel(model)

model 을 pytorch의 nn 라이브러리의 DataParallel 클래스로 감싸준다.

- GPU1, GPU2, GPU3, GPU4 에서 GPU1이 메인 GPU라면 GPU1에 과부하가 발생

 

2. DDP (Distributed Data Parallel)

멀티쓰레딩, sampler 활용한 데이터 분산처리방식

from torch.nn.parallel import DistributedDataParallel as DDP
model = DDP(model, device_ids=[rank])

multithreading 을 활용한 Data Parallel 방식

- 메인 GPU 과부하 문제 해결

sampler를 통해 각 GPU에 서로 다른 데이터가 전송되며, 각 데이터를 이용해서 모델 파라미터의 gradients A, B, C, D를 계산합니다. 이후 All Reduce 연산을 통해 gradients A, B, C, D에 대한 평균을 구한 뒤, 모든 GPU에 전달됩니다. 이후 optimizer의 step을 통해 각 GPU에서 모델 파라미터가 업데이트 되고, 똑같은 gradients 값을 사용했기 때문에, 똑같은 모델 정보가 보장됩니다.

 

3. FSDP (Fully Sharded Data Parallel)

모델의 정보와 데이터를 분산시켜 처리하는 방식

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
model = FSDP(model, auto_wrap_policy=wrap_policy, device_id=rank)

반면 FSDP에서는 모델의 모든 정보가 하나의 GPU에 있는 것이 아니라, 여러 GPU에 분산되어(sharded) 있습니다. 따라서 forward 과정에서 모델의 각 layer를 통과할 때마다 다른 GPU에 저장되어 있는 파라미터를 가져와 사용하고 제거합니다 (All Gather 연산). 이후 backward 과정에서 다시 gradients를 계산하기 위해 다른 GPU에 저장되어 있는 파라미터를 가져와서 사용하고 (All Gather 연산), 각 GPU에서 계산된 gradients를 다시 원래 속해 있던 GPU에 전달하기 위해서 Reduce Scatter 연산을 사용합니다. 최종적으로 각 GPU에는 각 GPU가 갖고 있던 모델에 대한 gradients만 남기 때문에, 이후 optimizer의 step 연산을 통해 모델의 파라미터를 업데이트할 수 있습니다. (각 연산에 대한 자세한 설명은 NCCL 문서를 참고해 주세요.)

 

  • 네이버 블로그 공유
  • 네이버 밴드 공유
  • 페이스북 공유