这周主要讲了一些根号算法。

1126 骆源的哈夫曼树

题意:给出 n(n105)n(n\le 10 ^{5}) 个数据,构建一棵 m(m103)m(m\le 10 ^{3}) 的哈夫曼树,并给出其 WPL(带权路径长度,即所有编码长度乘上权值的和)。

解析:就是裸的哈夫曼树题,具体实现方式就是贪心,将较小权值的数据放到更深的地方。首先将所有的权值加入小根堆中,选出最小的 mm 个权值取出,求和后再放回堆中,实际上就是给这 mm 个权值建立一个父节点,这时将这 mm 个权值的和加到答案中,算作这一深度的贡献。因此整体只需要一个小根堆就能实现。唯一的细节问题是需要注意补零的情况:对于 nn 个数据的哈夫曼树,每次取出 mm 个放回 11 个,实际每次减少了 m1m-1 个,最后要在堆中剩下一个作为根节点,因此需要保证 (m1)(n1)(m-1)|(n-1)。所以直接补零补到 (m1)(n1)(m-1)|(n-1) 为止就行了。

代码

//
// Created by 屋顶上的小丑 on 2023/3/26.
//
#include<cstdio>
#include<iostream>
const int maxn=1e5;
const int maxm=1e3;
int n,m,cnt;
int heap[(maxn<<1)+5],f[maxn+5];
long long ans;
void add(long long x)
{
    heap[++cnt]=x;
    int now=cnt;
    while(now>1)
    {
        int nxt=now>>1;
        if(heap[nxt]>heap[now])
            std::swap(heap[nxt],heap[now]);
        else
            return;
        now=nxt;
    }
}
void del()
{
    std::swap(heap[cnt],heap[1]);
    cnt--;
    int now=1;
    while((now<<1)<=cnt)
    {
        int nxt=now<<1;
        if(nxt+1<=cnt&&heap[nxt]>heap[nxt+1])
            nxt++;
        if(heap[nxt]<heap[now])
            std::swap(heap[nxt],heap[now]);
        else
            return;
        now=nxt;
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&f[i]);
        add(f[i]);
    }
    while((n-1)%(m-1))
    {
        add(0ll);
        n++;
    }
    while(cnt>=m)
    {
        int t=m;
        long long now=0;
        while(t--)
        {
            now+=heap[1];
            del();
        }
        ans+=now;
        add(now);
    }
    printf("%lld",ans);
    return 0;
}

1822 某科学的加密算法

题意:有 Q(Q100)Q(Q\le 100) 个人分别收到了 Ti(Ti2×105)T _{i}(\sum T _{i}\le2\times 10 ^{5}) 条由它们的公钥加密的密文,现在给出密文与公钥,请你破解出明文。

公钥加密算法如下:

  • 密钥生成流程
    • 随机生成一个 κ(κ31)\kappa(\kappa \le 31) 比特的素数 p(κ=max{log2pi+1})p(\kappa = \max \{\lfloor \log _{2} p _{i}\rfloor+1\}),满足 p=2q+1p=2q+1,其中 qq 也是素数;
    • 找到 pp 的一个原根 aa,令 g=a2modpg= a ^{2}\bmod p
    • 均匀随机生成一个 1q11\sim q-1 的整数 xx,令 y=gxmodpy=g ^{x}\bmod p
    • 生成公钥 pk=(q,g,y)pk=(q,g,y),私钥 sk=xsk=x
  • 加密流程(明文 mm 满足存在 kZk\in \Zm=gkmodpm= g ^{k}\bmod p
    • 均匀随机生成一个 1q11\sim q-1 的整数 rr
    • c1=grmodp,c2=yrmmodpc _{1}=g ^{r}\bmod p,c _{2}=y ^{r}m\bmod p,密文即为 (c1,c2)(c _{1},c _{2})

解析:裸的 BSGS 题。

易知 c2=yrmmodp=gxrmmodp=c1xmmodpc _{2}=y ^{r}m\bmod p=g ^{xr}m\bmod p=c _{1} ^{x}m\bmod p,因此 mc2c1x(modp)m\equiv \dfrac{c _{2}}{c _{1} ^{x}}\pmod p,我们只需要解出 xx 即可。

又因为 y=gxmodpy=g ^{x}\bmod p,因此 xloggy(modp)x\equiv \log _{g}y\pmod p,用 BSGS 解出即可。

代码

//
// Created by 屋顶上的小丑 on 2023/3/27.
//
#include<cstdio>
#include<cmath>
const int mod=1e6+7;
int Q,T;
long long q,g,y;
long long c1,c2;
struct hash_map
{
    struct data
    {
        long long u;
        int v;
        int nex;
    }e[(mod<<1)+5];
    int head[mod+5],cnt,tim[(mod<<1)+5];
    int hash(long long u) const { return u%mod; }
    int find(long long u) const
    {
        int fir=hash(u);
        for(int i=head[fir];i;i=e[i].nex)
            if(e[i].u==u)
                return e[i].v;
        return -1;
    }
    int& operator[](long long u)
    {
        int fir=hash(u);
        for(int i=head[fir];i;i=e[i].nex)
            if(e[i].u==u)
                return e[i].v;
        e[++cnt]=(data){u,-1,head[fir]};
        tim[cnt]=fir;
        head[fir]=cnt;
        return e[cnt].v;
    }
    void clear()
    {
        for(int i=1;i<=cnt;i++)
            head[tim[i]]=0;
        cnt=0;
    }
};
long long power(long long a,long long b,long long p)
{
    long long ans=1;
    for(;b;b>>=1)
    {
        if(b&1)
            ans=ans*a%p;
        a=a*a%p;
    }
    return ans%p;
}
hash_map hash;
long long BSGS(long long a,long long b,long long p)
{
    hash.clear();
    b%=p;
    if(!a)
    {
        if(!b)
            return 1;
        else
            return -1;
    }
    int t=sqrt(p)+1;
    long long val=b;
    hash[val]=0;
    for(int j=1;j<t;j++)
    {
        val=val*a%p;
        hash[val]=j;
    }
    long long num=power(a,t,p)%p;
    val=1;
    for(int i=1;i<=t;i++)
    {
        val=val*num%p;
        int j=hash.find(val);
        if(j>=0&&i*t-j>=0)
            return i*t-j;
    }
    return -1;
}
int main()
{
    scanf("%d",&Q);
    while(Q--)
    {
        scanf("%lld%lld%lld%d",&q,&g,&y,&T);
        long long p=(q<<1)+1;
        long long x=BSGS(g,y,p);
        while(T--)
        {
            scanf("%lld%lld",&c1,&c2);
            printf("%lld ",c2*power(c1,x*(p-2)%(p-1),p)%p);
        }
        printf("\n");
    }
    return 0;
}

1823 条件糖果甜度问题

题意:有一个长度为 n(n5×104)n(n\le 5\times 10 ^{4}) 的序列 si(1sin)s _{i}(1\le s _{i}\le n),给出 Q(Q2.5×105)Q(Q\le 2.5\times 10 ^{5}) 个询问,每次询问 [l,r][l,r] 中有多少个不同的 sis _{i}[a,b][a,b] 之间。

解析:原题大概是这个[AHOI2013] 作业

非常典型的莫队题,我们能够维护 [l,r][l,r] 中每种 sis _{i} 具体出现了多少次,然后就能判断它出没出现过,现在考虑有多少个不同在 [a,b][a,b] 之间的 sis _{i} 如何维护。

容易想到可以用一个权值树状数组来判断,每次如果有一个新的 sis _{i} 加进来就插进树状数组中,如果有一个 sis _{i} 的出现次数变成 00 就把它从树状数组中删掉,此时虽然单次查询只需要 O(logn)O(\log n) 的复杂度,但单次移动区间时也会将复杂度会带上一个 log\log,因此总复杂度大概为 O(Qnlogn)O(Q\sqrt{n}\log n)

考虑能不能把这个 log\log 给去掉。我们可以通过分块来维护上面的信息。对值域 [1,maxsi][1,\max s _{i}] 分块,这样莫队移动区间时就能够做到 O(1)O(1) 修改,查询时也是 O(n)O(\sqrt{n}) 的复杂度,总复杂度为 O(Qn)O(Q\sqrt{n})

代码

//
// Created by 屋顶上的小丑 on 2023/3/26.
//
#include<cstdio>
#include<cmath>
const int maxn=5e4;
const int maxq=2.5e5;
int n,Q,len,maxx;
int s[maxn+5],bl[maxn+5];
int val[maxn+5],cnt[maxn+5];
int bel[maxn+5],ans[maxq+5];
struct ques
{
    int l,r;
    int a,b;
    int id;
}q[maxq+5],tmp[maxq+5];
bool cmp(ques x,ques y)
{
    if(bl[x.l]==bl[y.l])
    {
        if(bl[x.l]&1)
            return x.r<y.r;
        return x.r>y.r;
    }
    return x.l<y.l;
}
void merge_sort(ques* x,int l,int r)
{
    if(l==r)
        return ;
    int mid=(l+r)>>1;
    merge_sort(x,l,mid);
    merge_sort(x,mid+1,r);
    int s1=l,s2=mid+1,s3=0;
    while(s1<=mid&&s2<=r)
    {
        if(cmp(q[s1],q[s2]))
        {
            tmp[++s3]=q[s1];
            s1++;
        }
        else
        {
            tmp[++s3]=q[s2];
            s2++;
        }
    }
    for(;s1<=mid;s1++) tmp[++s3]=q[s1];
    for(;s2<=r;s2++) tmp[++s3]=q[s2];
    for(int i=l;i<=r;i++)
        q[i]=tmp[i-l+1];
}
void add(int x)
{
    val[s[x]]++;
    if(val[s[x]]==1)
        cnt[bel[s[x]]]++;
}
void del(int x)
{
    val[s[x]]--;
    if(!val[s[x]])
        cnt[bel[s[x]]]--;
}
int query(int l,int r)
{
    if(l>maxx) return 0;
    int L=bel[l],R=bel[r],ret=0;
    if(L==R)
    {
        for(int i=l;i<=r;i++)
            ret+=(bool)val[i];
        return ret;
    }
    for(int i=L+1;i<R;i++)
        ret+=cnt[i];
    for(int i=l;i<=maxx&&bel[i]==L;i++)
        ret+=(bool)val[i];
    for(int i=r;i>=1&&bel[i]==R;i--)
        ret+=(bool)val[i];
    return ret;
}
int main()
{
    scanf("%d%d",&n,&Q);
    len=n/sqrt(Q*2.0/3.0);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&s[i]);
        if(s[i]>maxx) maxx=s[i];
        bl[i]=(i-1)/len+1;
    }
    len=sqrt(maxx);
    if(!len) len=1;
    for(int i=1;i<=maxx;i++)
        bel[i]=(i-1)/len+1;
    for(int i=1;i<=Q;i++)
    {
        scanf("%d%d%d%d",&q[i].l,&q[i].r,&q[i].a,&q[i].b);
        if(q[i].b>maxx) q[i].b=maxx;
        q[i].id=i;
    }
    merge_sort(q,1,Q);
    int L=1,R=0;
    for(int i=1;i<=Q;i++)
    {
        while(L>q[i].l) add(--L);
        while(R<q[i].r) add(++R);
        while(L<q[i].l) del(L++);
        while(R>q[i].r) del(R--);
        ans[q[i].id]=query(q[i].a,q[i].b);
    }
    for(int i=1;i<=Q;i++)
        printf("%d\n",ans[i]);
    return 0;
}

1824 动态调整的势函数

题意:有一棵 n(n105)n(n\le 10 ^{5}) 的有根树,每个点都有一个权值 wi(wi109)w _{i}(w _{i}\le 10 ^{9}),定义 sis _{i} 为以 ii 为根节点的子树权值和(即 si=xsubtree(i)wxs _{i}=\displaystyle\sum _{x\in\text{subtree}(i)}w _{x}),现在有 Q(Q105)Q(Q\le 10 ^{5}) 个操作如下:

  • 求出编号为 [l,r][l,r] 的节点的 sis _{i} 之和,即 i=lrsi\displaystyle\sum _{i=l} ^{r}s _{i}
  • 改变某个点的权值 wiw _{i}

解析:考虑分块求解。

对编号分块,维护 sis _{i} 的区间和,这样查询时整块可以直接利用维护的值,但对于整块之外的散点怎么快速求出 sis _{i} 呢?要快速求出子树权值和,我们可以通过欧拉序将它转为区间问题,这样求的就是某个区间上的 wiw _{i} 之和,用树状数组维护即可,单次查询复杂度为 O(nlogn)O(\sqrt{n}\log n)

现在来到修改操作,显然对于散点,直接在树状数组上修改即可,但是对于整块则不太好处理。我们考虑修改 wiw _{i} 后会影响哪些节点,易知从 ii 到根节点的这条链上的节点的 sis _{i} 会受到影响,但这些节点的编号可能并不连续,它们可能属于不同的块中,如果我们直接往上扫一遍来更新复杂度肯定是无法接受的。

事实上,对于整块而言,我们并不关心其中哪些节点的 sis _{i} 发生了变化,我们只需知道有多少个这个块中的 sis _{i} 发生了变化便足以维护这个块的 sis _{i} 之和,而这个数量我们可以直接在 dfs 的时候顺便就维护。用一个数组 num0,inum _{0,i} 记录遍历到当前节点 uu 时第 ii 个块中有多少个节点在 uu 到根节点的链上,然后用 numu,inum _{u,i} 将此时的数组拷贝一遍记录下来,即可得到 uu 到根节点的路径上每个块有多少个节点。修改时直接把所有块扫一遍加上这个块的贡献即可。此时单次修改复杂度为 O(n)O(\sqrt{n})。故总复杂度为 O(Qnlogn)O(Q\sqrt{n}\log n),如果能够改变块长和使用 vector 能降到 O(Qnlogn)O(Q\sqrt{n\log n})

依然可以用类似上一题的方法去掉这个 log\log。不用树状数组维护,改为用分块。思考树状数组本质上维护的是什么,无非是求出的欧拉序上的 wiw _{i} 的前缀和,而这个前缀和我们也能通过分块来维护。对欧拉序分块,维护块内前缀和和块前缀和,将修改操作对整块的块前缀和的影响用一个标记数组记录(代表这个块都要加上这个值),就能够将单词查询复杂度降为 O(n)O(\sqrt{n})。此时总复杂度为 O(Qn)O(Q\sqrt{n})

代码

树状数组实现:

//
// Created by 屋顶上的小丑 on 2023/3/27.
//
#include<cstdio>
#include<cmath>
const int maxn=1e5;
const int maxlen=400;
int n,Q,l[maxn+5],r[maxn+5];
int len,bel[maxn+5],tot,root;
int num[maxn+5][maxlen+5],tim;
int head[maxn+5],cnt,w[maxn+5];
long long sum[maxn+5],tr[maxn+5];
long long siz[maxn+5];
struct node
{
    int nex;
    int to;
}e[(maxn<<1)+5];
void add(int from,int to)
{
    e[++cnt].nex=head[from];
    e[cnt].to=to;
    head[from]=cnt;
}
int lowbit(int x)
{
    return x&(-x);
}
void Add(int x,int val)
{
    for(int i=x;i<=n;i+=lowbit(i))
        tr[i]+=val;
}
long long Query(int x)
{
    long long ret=0;
    for(int i=x;i;i-=lowbit(i))
        ret+=tr[i];
    return ret;
}
void dfs(int u,int fa)
{
    l[u]=++tim;
    num[0][bel[u]]++;
    siz[u]=w[u];
    Add(l[u],w[u]);
    for(int i=1;i<=tot;i++)
        num[u][i]=num[0][i];
    for(int i=head[u];i;i=e[i].nex)
    {
        int v=e[i].to;
        if(v==fa)
            continue;
        dfs(v,u);
        siz[u]+=siz[v];
    }
    r[u]=tim;
    num[0][bel[u]]--;
    sum[bel[u]]+=siz[u];
}
int main()
{
    scanf("%d%d",&n,&Q);
    len=sqrt(n);
    tot=n/len+(n%len!=0);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&w[i]);
        bel[i]=(i-1)/len+1;
    }
    int u,v;
    for(int i=1;i<=n;i++)
    {
        scanf("%d%d",&u,&v);
        if(!u)
            root=v;
        else
        {
            add(u,v);
            add(v,u);
        }
    }
    dfs(root,0);
    int op,x,y;
    for(int i=1;i<=Q;i++)
    {
        scanf("%d%d%d",&op,&x,&y);
        if(op==1)
        {
            int delta=y-w[x];
            for(int j=1;j<=tot;j++)
                sum[j]+=1ll*num[x][j]*delta;
            w[x]=y;
            Add(l[x],delta);
        }
        else
        {
            long long ans=0;
            int L=bel[x],R=bel[y];
            if(L==R)
            {
                for(int j=x;j<=y;j++)
                    ans+=Query(r[j])-Query(l[j]-1);
            }
            else
            {
                for(int j=x;bel[j]==L;j++)
                    ans+=Query(r[j])-Query(l[j]-1);
                for(int j=L+1;j<R;j++)
                    ans+=sum[j];
                for(int j=y;bel[j]==R;j--)
                    ans+=Query(r[j])-Query(l[j]-1);
            }
            printf("%lld\n",ans);
        }
    }
    return 0;
}

分块实现:

//
// Created by 屋顶上的小丑 on 2023/3/27.
//
#include<cstdio>
#include<cmath>
const int maxn=1e5;
const int maxlen=400;
int n,Q,l[maxn+5],r[maxn+5];
int len,bel[maxn+5],tot,root;
int num[maxn+5][maxlen+5],tim;
int head[maxn+5],cnt,w[maxn+5];
long long sum[maxn+5],siz[maxn+5];
long long s[maxn+5],tag[maxn+5];
struct node
{
    int nex;
    int to;
}e[(maxn<<1)+5];
void add(int from,int to)
{
    e[++cnt].nex=head[from];
    e[cnt].to=to;
    head[from]=cnt;
}
void dfs(int u,int fa)
{
    l[u]=++tim;
    num[0][bel[u]]++;
    siz[u]=s[tim]=w[u];
    for(int i=1;i<=tot;i++)
        num[u][i]=num[0][i];
    for(int i=head[u];i;i=e[i].nex)
    {
        int v=e[i].to;
        if(v==fa)
            continue;
        dfs(v,u);
        siz[u]+=siz[v];
    }
    r[u]=tim;
    num[0][bel[u]]--;
    sum[bel[u]]+=siz[u];
}
int main()
{
    scanf("%d%d",&n,&Q);
    len=sqrt(n);
    tot=n/len+(n%len!=0);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&w[i]);
        bel[i]=(i-1)/len+1;
    }
    int u,v;
    for(int i=1;i<=n;i++)
    {
        scanf("%d%d",&u,&v);
        if(!u)
            root=v;
        else
        {
            add(u,v);
            add(v,u);
        }
    }
    dfs(root,0);
    for(int i=1;i<=n;i++)
        s[i]+=s[i-1];
    int op,x,y;
    for(int i=1;i<=Q;i++)
    {
        scanf("%d%d%d",&op,&x,&y);
        if(op==1)
        {
            int delta=y-w[x];
            for(int j=1;j<=tot;j++)
                sum[j]+=1ll*num[x][j]*delta;
            w[x]=y;
            int L=l[x];
            for(int j=L;bel[j]==bel[L];j++)
                s[j]+=delta;
            for(int j=bel[L]+1;j<=tot;j++)
                tag[j]+=delta;
        }
        else
        {
            long long ans=0;
            int L=bel[x],R=bel[y];
            if(L==R)
            {
                for(int j=x;j<=y;j++)
                    ans+=s[r[j]]+tag[bel[r[j]]]-s[l[j]-1]-tag[bel[l[j]-1]];
            }
            else
            {
                for(int j=x;bel[j]==L;j++)
                    ans+=s[r[j]]+tag[bel[r[j]]]-s[l[j]-1]-tag[bel[l[j]-1]];
                for(int j=L+1;j<R;j++)
                    ans+=sum[j];
                for(int j=y;bel[j]==R;j--)
                    ans+=s[r[j]]+tag[bel[r[j]]]-s[l[j]-1]-tag[bel[l[j]-1]];
            }
            printf("%lld\n",ans);
        }
    }
    return 0;
}