"""Count how many days a stock of sock pairs lasts.

Input (stdin): two integers ``n k`` — ``n`` initial pairs, and one new pair
is obtained on the evening of every ``k``-th day. One pair is used up each
day. Output (stdout): the number of days before running out.
"""


def days_until_out(n: int, k: int) -> int:
    """Return the number of days ``n`` pairs last when every k-th day adds one.

    Each day consumes one pair; on the evening of days k, 2k, 3k, ... one
    extra pair is added. The first ``n`` days are always covered, and every
    ``k - 1`` pairs consumed beyond the first one "earn" one extra day, which
    yields the closed form ``n + (n - 1) // (k - 1)``.

    Args:
        n: Initial number of pairs (must be >= 0).
        k: Period of the refills in days (must be >= 2).

    Returns:
        The number of days until no pair is left; 0 when ``n == 0``.

    Raises:
        ValueError: If ``n < 0``, or if ``k < 2`` (with k == 1 a pair is
            replaced every day and the stock never runs out).
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    if k < 2:
        raise ValueError(f"k must be at least 2, got {k}")
    if n == 0:
        # No pair to wear on day 1; the closed form would give -1 here.
        return 0
    return n + (n - 1) // (k - 1)


def main() -> None:
    """Read ``n k`` from stdin and print the number of days."""
    n, k = map(int, input().split())
    print(days_until_out(n, k))


if __name__ == "__main__":
    main()