"""Sum, over all students, the distance from each student to the nearest school.

Input (stdin):
    line 1: two integers (presumably the school and student counts; they are
            parsed but not otherwise used)
    line 2: the school values
    line 3: the student values

Output (stdout): the total of the per-student minimum absolute differences.
"""
from bisect import bisect_left


def nearest_distance(sorted_schools, value):
    """Return the smallest ``abs(school - value)`` over ``sorted_schools``.

    ``sorted_schools`` must be a non-empty list sorted in ascending order.
    """
    pos = bisect_left(sorted_schools, value)
    # Only the elements directly around the insertion point can be closest:
    # the last one below ``value`` and the first one at or above it.  The
    # slice clamps itself at both ends of the list.
    candidates = sorted_schools[max(pos - 1, 0):pos + 1]
    return min(abs(school - value) for school in candidates)


def total_min_distance(schools, students):
    """Return the sum of each student's distance to the closest school.

    Args:
        schools: iterable of integer school values (any order).
        students: iterable of integer student values (any order).

    Returns:
        The integer total; ``0`` when there are no students.

    Raises:
        ValueError: if there are students but no schools to match them with.
    """
    students = list(students)
    if not students:
        return 0
    ordered_schools = sorted(schools)
    if not ordered_schools:
        raise ValueError("at least one school is required")
    # The students do not need sorting: the sum is order-independent.
    return sum(nearest_distance(ordered_schools, student) for student in students)


def main():
    """Read the three input lines from stdin and print the total distance."""
    # The two header integers are validated by the unpacking but the list
    # lengths themselves drive the computation.
    _m, _n = map(int, input().split())
    schools = list(map(int, input().split()))
    students = list(map(int, input().split()))
    print(total_min_distance(schools, students))


if __name__ == "__main__":
    main()