// Implemented with an AVL tree (original note: 用avl做的).
#include <cstdio>
#include <cstdlib>
#include <iostream>
// Larger / smaller of two ints. Real functions instead of function-like macros:
// each argument is evaluated exactly once, and the names no longer rewrite any
// later `std::max(` / `std::min(` token. With `using namespace std;` in effect,
// these non-template overloads are preferred over std::max/std::min for int calls.
constexpr int max(int a, int b) { return a > b ? a : b; }
constexpr int min(int a, int b) { return a < b ? a : b; }
using namespace std;
// One AVL tree node: two child links, the stored key and the cached subtree height.
struct node {
    node* rchild;  // right subtree (insert places larger keys here)
    node* lchild;  // left subtree (insert places smaller keys here)
    int data;      // the key
    int height;    // height of the subtree rooted at this node (a lone node has 1)
};
using Node = node;
// Cached height of the subtree rooted at `root`; an empty subtree counts as 0.
int getheight(Node* root)
{
    return root != nullptr ? root->height : 0;
}
// Left rotation: root's right child `y` becomes the new subtree root. `root` becomes
// y's left child, and y's former left subtree becomes root's right subtree.
// Both nodes whose children changed get their heights recomputed from BOTH children,
// demoted node first. The returned subtree therefore has correct cached heights no
// matter what state the caller was in. (Previously only root's left child was used,
// and y's height was left stale for the caller to fix.)
// Precondition: root and root->rchild are non-null.
Node* rotate_l(Node* root)
{
    Node* y = root->rchild;
    root->rchild = y->lchild;
    y->lchild = root;
    root->height = max(getheight(root->lchild), getheight(root->rchild)) + 1;
    y->height = max(getheight(y->lchild), getheight(y->rchild)) + 1;
    return y;
}
// Right rotation: root's left child `y` becomes the new subtree root. `root` becomes
// y's right child, and y's former right subtree becomes root's left subtree.
// Both nodes whose children changed get their heights recomputed from BOTH children,
// demoted node first. The returned subtree therefore has correct cached heights no
// matter what state the caller was in. (Previously only root's right child was used,
// and y's height was left stale for the caller to fix.)
// Precondition: root and root->lchild are non-null.
Node* rotate_r(Node* root)
{
    Node* y = root->lchild;
    root->lchild = y->rchild;
    y->rchild = root;
    root->height = max(getheight(root->lchild), getheight(root->rchild)) + 1;
    y->height = max(getheight(y->lchild), getheight(y->rchild)) + 1;
    return y;
}
// Inserts key `d` into the AVL subtree rooted at `root`. Returns the root of the
// resulting subtree, which may differ from `root` after a rotation.
// A key equal to one already present is not added again.
// If the recursive insert leaves the grown side 2 taller than the other side, the
// subtree is rebalanced: by a single rotation, or by a double rotation when the
// inner grandchild is the taller one.
Node* insert(int d, Node* root)
{
    if (root == nullptr) {
        Node* leaf = new Node();
        leaf->data = d;
        leaf->lchild = nullptr;
        leaf->rchild = nullptr;
        leaf->height = 1;  // both children are empty (height 0), so 0 + 1
        return leaf;
    }

    if (d < root->data) {
        root->lchild = insert(d, root->lchild);
        if (getheight(root->lchild) - getheight(root->rchild) == 2) {
            Node* const heavy = root->lchild;
            // Left-right shape: straighten it into left-left before the main rotation.
            if (getheight(heavy->rchild) > getheight(heavy->lchild))
                root->lchild = rotate_l(heavy);
            root = rotate_r(root);
        }
    } else if (root->data < d) {
        root->rchild = insert(d, root->rchild);
        if (getheight(root->rchild) - getheight(root->lchild) == 2) {
            Node* const heavy = root->rchild;
            // Right-left shape: straighten it into right-right before the main rotation.
            if (getheight(heavy->lchild) > getheight(heavy->rchild))
                root->rchild = rotate_r(heavy);
            root = rotate_l(root);
        }
    }
    // d == root->data falls through: nothing is inserted, the height is just refreshed.

    root->height = max(getheight(root->lchild), getheight(root->rchild)) + 1;
    return root;
}
// Returns the smallest |d - k| over the keys k stored in the BST rooted at `root`.
//
// The key nearest to d is its in-order predecessor or successor. Both of these lie on
// the path a lookup of d follows from `root`, so it is enough to take the minimum
// distance over the nodes visited by that lookup.
// The previous version only compared against the immediate parent. It therefore
// missed the last ancestor where the path turned. Example: keys {10, 5, 20, 15},
// query 12 gave 3 instead of 2.
//
// pred / lr describe how `root` was reached and are kept for interface compatibility:
//   - lr == 0: root is the left child of a node holding `pred`;
//   - lr == 1: root is the right child of such a node;
//   - any other value (main passes 2): there is no parent.
// When a parent exists, `pred` is also counted as a candidate.
// An empty tree with no parent candidate yields 0.
int search(Node* root,int d,int pred,int lr)
{
    const bool has_parent = (lr == 0 || lr == 1);
    if (!root)
        return has_parent ? abs(d - pred) : 0;

    int best = abs(d - root->data);
    if (has_parent) {
        const int to_parent = abs(d - pred);
        if (to_parent < best)
            best = to_parent;
    }

    // Follow the ordinary BST lookup path for d, keeping the closest key seen.
    for (Node* x = root; x != nullptr; x = (d < x->data) ? x->lchild : x->rchild) {
        const int dist = abs(d - x->data);
        if (dist < best)
            best = dist;
        if (dist == 0)
            break;  // exact match: nothing can be closer
    }
    return best;
}
// Reads the next integer from stdin with getchar().
// Any non-digit characters before it are skipped; a '-' among them makes the result
// negative. Returns 0 if end of input is reached before any digit.
// The EOF check is essential: EOF compares below '0', so without it the skip loop
// would never terminate on truncated input.
inline int read()
{
    int num = 0, w = 1;
    int ch = getchar();  // int, not char, so EOF stays distinguishable from a data byte
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-')
            w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        num = num * 10 + (ch - '0');
        ch = getchar();
    }
    return num * w;
}
// Reads n, then n integers a_1..a_n, and prints the total:
//   a_1 + sum over i >= 2 of the distance from a_i to the nearest earlier value,
// where that distance is the one search() reports against the tree of earlier values.
// Each value is inserted into the AVL tree after it has been scored.
int main()
{
    const int n = read();
    if (n <= 0) {  // no values follow: don't read or score a value that isn't there
        std::cout << 0;
        return 0;
    }

    Node* root = nullptr;
    int x = read();
    root = insert(x, root);
    long long sum = x;  // 64-bit accumulator: many per-day distances can exceed INT_MAX

    for (int i = 1; i < n; ++i) {
        x = read();
        sum += search(root, x, root->data, 2);
        root = insert(x, root);
    }
    std::cout << sum;
    return 0;
}