我尝试以更清晰、更面向对象的方式去实现,看上去似乎是常数问题?
class Node:
    """One segment-tree node covering the inclusive index range ``[l, r]``.

    The node stores the sum of its range modulo ``m`` together with a pending
    affine update ``x -> x * mul + add``.  The pending update is already
    reflected in this node's ``sum`` but has not yet been forwarded to the
    children; ``push_down`` forwards it.
    """

    # Fixed attribute set: SegTree._build allocates one Node per segment, so
    # dropping the per-instance ``__dict__`` keeps each node small.
    __slots__ = ("l", "r", "len", "sum", "add", "mul", "m", "ls", "rs")

    def __init__(self, l: int, r: int, m: int) -> None:
        """Create a childless node for ``[l, r]`` working modulo ``m``."""
        self.l = l
        self.r = r
        self.len = r - l + 1  # number of elements covered
        self.sum = 0
        self.add = 0  # pending additive tag (identity: 0)
        self.mul = 1  # pending multiplicative tag (identity: 1)
        self.m = m
        self.ls = None  # left child, or None when no child is attached
        self.rs = None  # right child, or None when no child is attached

    def in_range(self, l: int, r: int) -> bool:
        """Return True when this node's range lies entirely inside ``[l, r]``."""
        return l <= self.l and self.r <= r

    def out_range(self, l: int, r: int) -> bool:
        """Return True when this node's range does not intersect ``[l, r]``."""
        return self.r < l or self.l > r

    def push_up(self) -> None:
        """Recompute ``sum`` from the two children (both must be attached)."""
        self.sum = (self.ls.sum + self.rs.sum) % self.m

    def apply(self, add: int, mul: int) -> None:
        """Apply ``x -> x * mul + add`` to every element of this range.

        The update is folded into ``sum`` immediately and composed onto the
        pending tags so it can be forwarded to the children later.
        """
        # Each of the ``len`` elements gains ``add`` after being scaled.
        self.sum = (self.sum * mul + add * self.len) % self.m
        # (x * self.mul + self.add) * mul + add
        #   == x * (self.mul * mul) + (self.add * mul + add)
        self.mul = (self.mul * mul) % self.m
        self.add = (self.add * mul + add) % self.m

    def push_down(self) -> None:
        """Forward the pending update to both children and clear it.

        Does nothing when the tags are the identity or when no left child is
        attached.
        """
        if (self.add != 0 or self.mul != 1) and self.ls is not None:
            self.ls.apply(self.add, self.mul)
            self.rs.apply(self.add, self.mul)
            self.mul = 1
            self.add = 0
class SegTree:
    """Lazy segment tree with range add, range multiply and range sum.

    All arithmetic is done modulo the value passed as ``m``.  Positions are
    0-based and every ``[l, r]`` range is inclusive on both ends.  A tree
    built from an empty list has no nodes: updates are no-ops and every
    query returns 0.
    """

    def __init__(self, nums, m):
        """Build the tree over ``nums`` with modulus ``m``."""
        self.n = len(nums)
        self._nums = nums
        # Always define ``root`` so an empty tree degrades gracefully instead
        # of raising AttributeError on the first update or query.
        self.root = self._build(0, self.n - 1, m) if self.n else None

    def _build(self, l, r, p):
        """Recursively build and return the node for ``[l, r]`` (modulus ``p``)."""
        node = Node(l, r, p)
        if l == r:
            # Leaf: holds a single input element, reduced modulo p.
            node.sum = self._nums[l] % p
            return node
        m = (l + r) // 2
        node.ls = self._build(l, m, p)
        node.rs = self._build(m + 1, r, p)
        node.push_up()
        return node

    def add(self, l, r, val):
        """Add ``val`` to every element in ``[l, r]``."""
        if self.root is not None:
            self._update(l, r, val, 1, self.root)

    def mul(self, l, r, val):
        """Multiply every element in ``[l, r]`` by ``val``."""
        if self.root is not None:
            self._update(l, r, 0, val, self.root)

    def _update(self, l, r, add, mul, node: "Node"):
        """Apply ``x -> x * mul + add`` to the part of ``[l, r]`` under ``node``."""
        if node.in_range(l, r):
            # Fully covered: record the update lazily on this node.
            node.apply(add, mul)
        elif not node.out_range(l, r):
            # Partial overlap: flush pending tags before touching children,
            # then refresh this node's sum from them.
            node.push_down()
            self._update(l, r, add, mul, node.ls)
            self._update(l, r, add, mul, node.rs)
            node.push_up()

    def query(self, l, r):
        """Return the sum of the elements in ``[l, r]`` modulo the modulus."""
        if self.root is None:
            return 0
        return self._query(l, r, self.root)

    def _query(self, l, r, node: "Node"):
        """Return the sum of the part of ``[l, r]`` under ``node``, reduced."""
        if node.in_range(l, r):
            # ``sum`` is kept reduced by _build, push_up and apply.
            return node.sum
        if node.out_range(l, r):
            return 0
        node.push_down()
        left = self._query(l, r, node.ls)
        right = self._query(l, r, node.rs)
        return (left + right) % node.m
import sys
input = sys.stdin.readline
# Driver.  Reads "n m p" from the first line, the n initial values from the
# second line, then m operations, one per line:
#   1 x y k  -> multiply positions x..y by k
#   2 x y k  -> add k to positions x..y
#   3 x y    -> print the sum of positions x..y modulo p
# Positions in the input are 1-based; the tree is 0-based, hence the "- 1".
n, m, p = map(int, input().split())
a = list(map(int, input().split()))
seg = SegTree(a, p)
answers = []
for _ in range(m):
    kind, x, y, *rest = map(int, input().split())
    if kind == 1:
        seg.mul(x - 1, y - 1, rest[0])
    elif kind == 2:
        seg.add(x - 1, y - 1, rest[0])
    else:
        answers.append(str(seg.query(x - 1, y - 1)))
# Emit all answers in one write rather than one call per query.
if answers:
    sys.stdout.write("\n".join(answers) + "\n")