大佬们看看吧,只对了1 2 10
查看原帖
大佬们看看吧,只对了1 2 10
412257
q3q4q5楼主2021/2/10 19:57

用avl做的

#include<iostream>

#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))

using namespace std;

typedef struct node {
	struct node* rchild;
	struct node* lchild;
	int data;
	int height;
}Node;

int getheight(Node* root)
{
	if (!root)
		return 0;
	else 
		return root->height;
}

Node* rotate_l(Node* root)
{
	Node* x = root->rchild->lchild;
	Node* y = root->rchild;
	root->rchild->lchild = root;
	root->rchild = x;
	root->height = getheight(root->lchild) + 1;
	return y;
}

Node* rotate_r(Node* root)
{
	Node* x = root->lchild->rchild;
	Node* y = root->lchild;
	root->lchild->rchild = root;
	root->lchild = x;
	root->height = getheight(root->rchild) + 1;
	return y;
}

Node* insert(int d, Node* root)
{
	if (!root) {
		root = new Node();
		root->data = d;
		root->lchild = root->rchild = nullptr;
	}
	else if (d < root->data) {
		root->lchild = insert(d, root->lchild);
		if (getheight(root->lchild) - getheight(root->rchild)==2) {
			if (getheight(root->lchild->rchild) > getheight(root->lchild->lchild)) {
				root->lchild = rotate_l(root->lchild);
			}
			root = rotate_r(root);
		}
	}
	else if (d > root->data) {
		root->rchild = insert(d, root->rchild);
		if (getheight(root->rchild) - getheight(root->lchild) == 2) {
			if (getheight(root->rchild->lchild) > getheight(root->rchild->rchild)) {
				root->rchild = rotate_r(root->rchild);
			}
			root = rotate_l(root);
		}
	}
	root->height = max(getheight(root->lchild), getheight(root->rchild)) + 1;
	return root;
}

int search(Node* root,int d,int pred,int lr)
{
	Node* x = root;
	if (x->data == d)
		return 0;
	if (d < x->data)
	{
		if (!x->lchild) {
			if (lr==1)
				return min(abs(pred-d), abs(d-x->data));
			else
				return abs(d - x->data);
		}
		else return search(root->lchild, d, root->data, 0);
	}
	if (d > x->data)
	{
		if (!x->rchild) {
			if (!lr)
				return min(abs(d - pred), abs(x->data - d));
			else
				return abs(x->data - d);
		}
		else return search(root->rchild, d, root->data, 1);
	}
}

inline int read()
{
	int num = 0, w = 1; char ch = 0;
	while (ch < '0' || ch>'9') { if (ch == '-') w = -1; ch = getchar(); }
	while (ch >= '0' && ch <= '9') num = (num << 3) + (num << 1) + ch - '0', ch = getchar();
	return num * w;
}

int main()
{
	int n;
	Node* root=nullptr;
	n = read();
	int x, sum = 0;
	x = read();
	root = insert(x, root);
	sum += x;
	for (int i = 0; i < n - 1; i++)
	{
		x = read();
		sum += search(root, x, root->data, 2);
		root = insert(x, root);
	}
	cout << sum;
}
2021/2/10 19:57
加载中...