Python 70 分 TLE 求助
查看原帖
Python 70 分 TLE 求助
1022761
wuuk楼主2025/6/11 23:25

我尝试以更清晰、更面向对象的方式去实现，看上去似乎是常数问题？

class Node:
    """One node of a lazy segment tree covering the index range [l, r].

    The node stores the sum of its range reduced modulo ``m`` together with
    a pending affine tag ``x -> x * mul + add`` that has not yet been
    forwarded to its children.
    """

    # One Node is allocated per tree segment and every operation touches
    # several of them; __slots__ drops the per-instance __dict__, giving
    # smaller objects and faster attribute access (a real constant-factor
    # win for a pure-Python segment tree).
    __slots__ = ("l", "r", "len", "sum", "add", "mul", "m", "ls", "rs")

    def __init__(self, l, r, m):
        self.l = l
        self.r = r
        self.len = r - l + 1  # number of elements covered
        self.sum = 0          # sum of the covered elements, reduced mod m
        self.add = 0          # pending additive tag
        self.mul = 1          # pending multiplicative tag
        self.m = m            # modulus
        self.ls = None        # left child (stays None for a leaf)
        self.rs = None        # right child (stays None for a leaf)

    def in_range(self, l, r):
        """Return True if [self.l, self.r] lies entirely inside [l, r]."""
        return l <= self.l and self.r <= r

    def out_range(self, l, r):
        """Return True if [self.l, self.r] does not intersect [l, r]."""
        return self.r < l or self.l > r

    def push_up(self):
        """Recompute this node's sum from its two children."""
        self.sum = (self.ls.sum + self.rs.sum) % self.m

    def apply(self, add, mul):
        """Apply ``x -> x * mul + add`` to every element of this range.

        The new map is composed after the pending tag:
        mul' * (mul * x + add) + add' == (mul' * mul) * x + (mul' * add + add').
        """
        self.sum = (self.sum * mul + add * self.len) % self.m
        self.mul = (self.mul * mul) % self.m
        self.add = (self.add * mul + add) % self.m

    def push_down(self):
        """Forward the pending tag to both children, then reset it.

        Does nothing for a leaf or when the tag is the identity map.
        """
        if (self.add != 0 or self.mul != 1) and self.ls is not None:
            self.ls.apply(self.add, self.mul)
            self.rs.apply(self.add, self.mul)
            self.mul = 1
            self.add = 0

class SegTree:
    """Segment tree supporting range add, range multiply and range sum mod m.

    Indices are 0-based and every range [l, r] is inclusive on both ends.
    """

    def __init__(self, nums, m):
        self.n = len(nums)
        # Always define root so that operations on an empty tree are
        # harmless (updates do nothing, queries return 0) instead of
        # raising AttributeError.
        self.root = None
        if self.n == 0:
            return
        self._nums = nums
        self.root = self._build(0, self.n - 1, m)

    def _build(self, l, r, p):
        """Build the subtree for [l, r] from self._nums and return its root."""
        node = Node(l, r, p)
        if l == r:
            node.sum = self._nums[l] % p
            return node
        mid = (l + r) // 2
        node.ls = self._build(l, mid, p)
        node.rs = self._build(mid + 1, r, p)
        node.push_up()
        return node

    def add(self, l, r, val):
        """Add ``val`` to every element in [l, r]."""
        if self.root is not None:
            self._update(l, r, val, 1, self.root)

    def mul(self, l, r, val):
        """Multiply every element in [l, r] by ``val``."""
        if self.root is not None:
            self._update(l, r, 0, val, self.root)

    def _update(self, l, r, add, mul, node: Node):
        """Apply ``x -> x * mul + add`` to the part of [l, r] under ``node``."""
        # Range tests are inlined instead of calling node.in_range/out_range,
        # and a child is only visited when it overlaps [l, r]; in CPython the
        # per-call overhead dominates, so fewer calls means less time.
        if node.r < l or node.l > r:
            return
        if l <= node.l and node.r <= r:
            node.apply(add, mul)
            return
        # Partial overlap: node cannot be a leaf, so both children exist.
        node.push_down()
        mid = node.ls.r
        if l <= mid:
            self._update(l, r, add, mul, node.ls)
        if r > mid:
            self._update(l, r, add, mul, node.rs)
        node.push_up()

    def query(self, l, r):
        """Return the sum of elements in [l, r] modulo m (0 for an empty tree)."""
        if self.root is None:
            return 0
        return self._query(l, r, self.root)

    def _query(self, l, r, node: Node):
        """Return the sum of the part of [l, r] under ``node``, modulo m."""
        if node.r < l or node.l > r:
            return 0
        if l <= node.l and node.r <= r:
            # node.sum is always stored already reduced modulo node.m.
            return node.sum
        # Partial overlap: node cannot be a leaf, so both children exist.
        node.push_down()
        mid = node.ls.r
        total = 0
        if l <= mid:
            total += self._query(l, r, node.ls)
        if r > mid:
            total += self._query(l, r, node.rs)
        return total % node.m

import sys
def main():
    """Read the array and operations from stdin and print each query answer.

    Input (whitespace-separated): n, m, p; then the n initial values; then
    m operations, each an op code followed by 1-based inclusive bounds x, y.
    Op 1 (followed by k) multiplies [x, y] by k. Op 2 (followed by k) adds k
    to [x, y]. Any other op code prints the sum of [x, y] modulo p.
    """
    # Reading every token at once and emitting all answers in one write
    # avoids per-line readline/split/write overhead, which is significant in
    # CPython for large m. Running inside a function makes names fast locals.
    # Assumes well-formed input: exactly n values follow n, m, p.
    tokens = map(int, sys.stdin.buffer.read().split())
    n, m, p = next(tokens), next(tokens), next(tokens)
    a = [next(tokens) for _ in range(n)]
    seg = SegTree(a, p)
    out = []
    for _ in range(m):
        op = next(tokens)
        x, y = next(tokens), next(tokens)
        if op == 1:
            seg.mul(x - 1, y - 1, next(tokens))
        elif op == 2:
            seg.add(x - 1, y - 1, next(tokens))
        else:
            out.append(seg.query(x - 1, y - 1))
    if out:
        sys.stdout.write("\n".join(map(str, out)) + "\n")


if __name__ == "__main__":
    main()
2025/6/11 23:25
加载中...