aboutsummaryrefslogtreecommitdiff
path: root/lib/CodeGen/ExpandMemCmp.cpp
blob: c5910c18d89bd9b0fd591ebddfc39bf549f09d00 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
//===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass tries to expand calls to memcmp() whose size is a compile-time
// constant into an inline sequence of loads and comparisons, using the load
// sizes and the maximum number of loads reported by the target.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Target/TargetLowering.h"
#include "llvm/Target/TargetSubtargetInfo.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;

#define DEBUG_TYPE "expandmemcmp"

// Pass statistics. All of them are updated in expandMemCmp():
//  - NumMemCmpCalls is bumped for every memcmp call that is considered.
//  - NumMemCmpNotConstant is bumped when the size operand is not a constant.
//  - NumMemCmpGreaterThanMax is bumped when the computed expansion contains no
//    loads (the decomposition was rejected).
//  - NumMemCmpInlined is bumped when the call is actually expanded.
STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

// Hidden command-line knob (default 1) forwarded to MemCmpExpansion as
// NumLoadsPerBlock. It is only consulted when the memcmp result is used in a
// zero-equality comparison; in that case several load pairs can be combined
// into a single basic block.
static cl::opt<unsigned> MemCmpNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));

namespace {


// This class provides helper functions to expand a memcmp library call into an
// inline expansion.
//
// Typical use (see expandMemCmp()): construct the object, check that
// getNumLoads() is non-zero, then call getMemCmpExpansion() to emit the IR and
// obtain the value that replaces the call.
class MemCmpExpansion {
  // The block that load-compare blocks branch to as soon as they find a
  // difference. PhiSrc1/PhiSrc2 collect the two differing loaded values so the
  // sign of the result can be computed; they are only created when the result
  // is not used in a zero-equality comparison.
  struct ResultBlock {
    BasicBlock *BB = nullptr;
    PHINode *PhiSrc1 = nullptr;
    PHINode *PhiSrc2 = nullptr;

    ResultBlock() = default;
  };

  // The memcmp call being expanded.
  CallInst *const CI;
  ResultBlock ResBlock;
  // Number of bytes compared by the call.
  const uint64_t Size;
  // Size in bytes of the widest load used by the expansion.
  unsigned MaxLoadSize;
  // Used as the reserved-operand hint when creating the result block phis.
  uint64_t NumLoadsNonOneByte;
  // Number of load pairs emitted per basic block (zero-equality case only).
  const uint64_t NumLoadsPerBlock;
  // The load-compare blocks, in execution order. Empty when the expansion is
  // emitted inline before the call without any new control flow.
  std::vector<BasicBlock *> LoadCmpBlocks;
  // The block split off at the call; every path of the expansion ends here.
  // Only set when getMemCmpExpansion() creates control flow.
  BasicBlock *EndBlock;
  // Phi in EndBlock that merges the results of all paths.
  PHINode *PhiRes;
  // True if the call result is only compared for (in)equality with zero.
  const bool IsUsedForZeroCmp;
  const DataLayout &DL;
  IRBuilder<> Builder;
  // Represents the decomposition in blocks of the expansion. Each entry is
  // {LoadSize, Offset}. For example, comparing 33 bytes on X86+sse can be done
  // with 2x16-byte loads and 1x1-byte load, which would be represented as
  // [{16, 0}, {16, 16}, {1, 32}].
  // TODO(courbet): Involve the target more in this computation. On X86, 7
  // bytes can be done more efficiently with two overlapping 4-byte loads than
  // covering the interval with [{4, 0}, {2, 4}, {1, 6}].
  struct LoadEntry {
    LoadEntry(unsigned LoadSize, uint64_t Offset)
        : LoadSize(LoadSize), Offset(Offset) {
      // Offsets must be a multiple of the load size so that the load can be
      // addressed with a GEP over elements of LoadSize bytes.
      assert(Offset % LoadSize == 0 && "invalid load entry");
    }

    // Index of this load when the buffer is viewed as an array of elements of
    // LoadSize bytes.
    uint64_t getGEPIndex() const { return Offset / LoadSize; }

    // The size of the load for this block, in bytes.
    const unsigned LoadSize;
    // The offset of this load WRT the base pointer, in bytes.
    const uint64_t Offset;
  };
  // The sequence of loads covering the compared bytes; empty if the call
  // should not be expanded.
  SmallVector<LoadEntry, 8> LoadSequence;

  // Helpers that create the skeleton of basic blocks and phi nodes.
  void createLoadCmpBlocks();
  void createResultBlock();
  void setupResultBlockPHINodes();
  void setupEndBlockPHINodes();
  // Helpers that fill in the load-compare blocks and the result block.
  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
  void emitLoadCompareBlock(unsigned BlockIndex);
  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                         unsigned &LoadIndex);
  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex);
  void emitMemCmpResultBlock();
  // Entry points for the specialised expansion shapes.
  Value *getMemCmpExpansionZeroCase();
  Value *getMemCmpEqZeroOneBlock();
  Value *getMemCmpOneBlock();

 public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
                  unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
                  unsigned NumLoadsPerBlock, const DataLayout &DL);

  // Number of load-compare blocks needed for the current LoadSequence.
  unsigned getNumBlocks();
  // Number of loads (per source buffer) in the decomposition; zero means the
  // call must not be expanded.
  uint64_t getNumLoads() const { return LoadSequence.size(); }

  // Emits the expansion and returns the value that replaces the call result.
  Value *getMemCmpExpansion();
};

// Computes the decomposition of a memcmp of `Size` bytes into a sequence of
// loads (LoadSequence), greedily using the load sizes in Options.LoadSizes.
// NOTE(review): the greedy scheme presumes LoadSizes is sorted in decreasing
// order - confirm against the targets' MemCmpExpansionOptions.
//
// No IR is created here. The basic block structure required for the expansion
// is created later by getMemCmpExpansion() and includes:
// 1. A list of load compare blocks - LoadCmpBlocks.
// 2. An EndBlock, split from original instruction point, which is the block to
// return from.
// 3. ResultBlock, block to branch to for early exit when a
// LoadCmpBlock finds a difference.
//
// On return, LoadSequence is left empty (getNumLoads() == 0) whenever the call
// must not be expanded:
//  - no entry of Options.LoadSizes is small enough for Size,
//  - the decomposition would need more than MaxNumLoads loads, or
//  - the available load sizes cannot cover all Size bytes exactly.
MemCmpExpansion::MemCmpExpansion(
    CallInst *const CI, uint64_t Size,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    const unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
    const unsigned NumLoadsPerBlock, const DataLayout &TheDataLayout)
    : CI(CI),
      Size(Size),
      MaxLoadSize(0),
      NumLoadsNonOneByte(0),
      NumLoadsPerBlock(NumLoadsPerBlock),
      EndBlock(nullptr),
      PhiRes(nullptr),
      IsUsedForZeroCmp(IsUsedForZeroCmp),
      DL(TheDataLayout),
      Builder(CI) {
  assert(Size > 0 && "zero blocks");
  // Scale the max size down if the target can load more bytes than we need.
  size_t LoadSizeIndex = 0;
  while (LoadSizeIndex < Options.LoadSizes.size() &&
         Options.LoadSizes[LoadSizeIndex] > Size) {
    ++LoadSizeIndex;
  }
  if (LoadSizeIndex == Options.LoadSizes.size()) {
    // Every available load is wider than the compared region (or the target
    // provided no load sizes at all). Indexing LoadSizes here would read past
    // its end, so give up and leave LoadSequence empty instead.
    return;
  }
  this->MaxLoadSize = Options.LoadSizes[LoadSizeIndex];
  // Compute the decomposition.
  uint64_t CurSize = Size;
  uint64_t Offset = 0;
  while (CurSize && LoadSizeIndex < Options.LoadSizes.size()) {
    const unsigned LoadSize = Options.LoadSizes[LoadSizeIndex];
    assert(LoadSize > 0 && "zero load size");
    const uint64_t NumLoadsForThisSize = CurSize / LoadSize;
    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
      // Do not expand if the total number of loads is larger than what the
      // target allows. Note that it's important that we exit before completing
      // the expansion to avoid using a ton of memory to store the expansion for
      // large sizes.
      LoadSequence.clear();
      return;
    }
    if (NumLoadsForThisSize > 0) {
      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
        LoadSequence.push_back({LoadSize, Offset});
        Offset += LoadSize;
      }
      if (LoadSize > 1) {
        // Every load wider than one byte contributes one incoming edge to the
        // result block phis, so count loads rather than distinct load sizes.
        NumLoadsNonOneByte += NumLoadsForThisSize;
      }
      CurSize = CurSize % LoadSize;
    }
    ++LoadSizeIndex;
  }
  if (CurSize != 0) {
    // The available load sizes could not cover the whole region (e.g. there is
    // no 1-byte load). Expanding anyway would compare fewer bytes than the
    // original call, so refuse to expand.
    LoadSequence.clear();
    return;
  }
  assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
}

// Returns the number of load-compare blocks needed for the current load
// sequence.
unsigned MemCmpExpansion::getNumBlocks() {
  const uint64_t TotalLoads = getNumLoads();

  // Unless the result is only tested against zero, each load pair lives in
  // its own block.
  if (!IsUsedForZeroCmp)
    return TotalLoads;

  // Otherwise NumLoadsPerBlock load pairs are packed into each block, with a
  // trailing partial block when the loads do not divide evenly.
  const uint64_t FullBlocks = TotalLoads / NumLoadsPerBlock;
  const bool HasPartialBlock = (TotalLoads % NumLoadsPerBlock) != 0;
  return FullBlocks + (HasPartialBlock ? 1 : 0);
}

void MemCmpExpansion::createLoadCmpBlocks() {
  for (unsigned i = 0; i < getNumBlocks(); i++) {
    BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
                                        EndBlock->getParent(), EndBlock);
    LoadCmpBlocks.push_back(BB);
  }
}

void MemCmpExpansion::createResultBlock() {
  ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
                                   EndBlock->getParent(), EndBlock);
}

// This function creates the IR instructions for loading and comparing 1 byte.
// It loads 1 byte from each source of the memcmp parameters with the given
// GEPIndex. It then subtracts the two loaded values and adds this result to the
// final phi node for selecting the memcmp result.
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
                                               unsigned GEPIndex) {
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Get the base address using the GEPIndex.
  if (GEPIndex != 0) {
    Source1 = Builder.CreateGEP(LoadSizeType, Source1,
                                ConstantInt::get(LoadSizeType, GEPIndex));
    Source2 = Builder.CreateGEP(LoadSizeType, Source2,
                                ConstantInt::get(LoadSizeType, GEPIndex));
  }

  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
  LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
  Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);

  PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);

  if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
    // Early exit branch if difference found to EndBlock. Otherwise, continue to
    // next LoadCmpBlock,
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
                                    ConstantInt::get(Diff->getType(), 0));
    BranchInst *CmpBr =
        BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
    Builder.Insert(CmpBr);
  } else {
    // The last block has an unconditional branch to EndBlock.
    BranchInst *CmpBr = BranchInst::Create(EndBlock);
    Builder.Insert(CmpBr);
  }
}

/// Generate an equality comparison for one or more pairs of loaded values.
/// This is used in the case where the memcmp() call is compared equal or not
/// equal to zero.
///
/// \p BlockIndex is the index into LoadCmpBlocks of the block to emit into; it
/// is ignored when LoadCmpBlocks is empty, in which case the code is emitted
/// right before the memcmp call.
/// \p LoadIndex is the index into LoadSequence of the first load to emit. It
/// is advanced past all the loads consumed by this call (at most
/// NumLoadsPerBlock).
///
/// Returns an i1 value that is true iff any of the compared load pairs differ.
Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
                                            unsigned &LoadIndex) {
  assert(LoadIndex < getNumLoads() &&
         "getCompareLoadPairs() called with no remaining loads");
  std::vector<Value *> XorList, OrList;

  const unsigned NumLoads =
      std::min(getNumLoads() - LoadIndex, NumLoadsPerBlock);
  assert(NumLoads > 0 && "a block must contain at least one load");

  // For a single-block expansion, start inserting before the memcmp call.
  if (LoadCmpBlocks.empty())
    Builder.SetInsertPoint(CI);
  else
    Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Cmp = nullptr;
  // If we have multiple loads per block, we need to generate a composite
  // comparison using xor+or. The type for the combinations is the largest load
  // type.
  IntegerType *const MaxLoadType =
      NumLoads == 1 ? nullptr
                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];

    IntegerType *LoadSizeType =
        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
    Type *LoadSizePtrType = LoadSizeType->getPointerTo();

    Value *Source1 = CI->getArgOperand(0);
    Value *Source2 = CI->getArgOperand(1);

    // Cast source to LoadSizeType*. The check must be made against the pointer
    // type: comparing a pointer type with the integer type is always unequal.
    if (Source1->getType() != LoadSizePtrType)
      Source1 = Builder.CreateBitCast(Source1, LoadSizePtrType);
    if (Source2->getType() != LoadSizePtrType)
      Source2 = Builder.CreateBitCast(Source2, LoadSizePtrType);

    // Get the base address using a GEP.
    if (CurLoadEntry.Offset != 0) {
      // GEP indices are interpreted as signed values, so do not build the
      // index in the (possibly very narrow) loaded type, where a large index
      // would be truncated or turn negative. Use a 64-bit index instead.
      Value *Index = Builder.getInt64(CurLoadEntry.getGEPIndex());
      Source1 = Builder.CreateGEP(LoadSizeType, Source1, Index);
      Source2 = Builder.CreateGEP(LoadSizeType, Source2, Index);
    }

    // Get a constant or load a value for each source address.
    Value *LoadSrc1 = nullptr;
    if (auto *Source1C = dyn_cast<Constant>(Source1))
      LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
    if (!LoadSrc1)
      LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);

    Value *LoadSrc2 = nullptr;
    if (auto *Source2C = dyn_cast<Constant>(Source2))
      LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
    if (!LoadSrc2)
      LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

    if (NumLoads != 1) {
      if (LoadSizeType != MaxLoadType) {
        LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
        LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
      }
      // If we have multiple loads per block, we need to generate a composite
      // comparison using xor+or. Both operands already have MaxLoadType, so
      // the xor needs no further extension.
      XorList.push_back(Builder.CreateXor(LoadSrc1, LoadSrc2));
    } else {
      // If there's only one load per block, we just compare the loaded values.
      Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
    }
  }

  // ORs neighbouring elements of InList together, halving the list. An odd
  // trailing element is carried over unchanged.
  auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
    std::vector<Value *> OutList;
    for (size_t i = 0; i + 1 < InList.size(); i += 2) {
      Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
      OutList.push_back(Or);
    }
    if (InList.size() % 2 != 0)
      OutList.push_back(InList.back());
    return OutList;
  };

  if (!Cmp) {
    // Pairwise OR the XOR results.
    OrList = pairWiseOr(XorList);

    // Pairwise OR the OR results until one result left.
    while (OrList.size() != 1) {
      OrList = pairWiseOr(OrList);
    }
    // The buffers differ iff any bit of the combined xor results is set.
    Cmp = Builder.CreateICmpNE(OrList[0],
                               ConstantInt::get(OrList[0]->getType(), 0));
  }

  return Cmp;
}

void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                                        unsigned &LoadIndex) {
  Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);

  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise,
  // continue to next LoadCmpBlock or EndBlock.
  BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function creates the IR intructions for loading and comparing using the
// given LoadSize. It loads the number of bytes specified by LoadSize from each
// source of the memcmp parameters. It then does a subtract to see if there was
// a difference in the loaded values. If a difference is found, it branches
// with an early exit to the ResultBlock for calculating which source was
// larger. Otherwise, it falls through to the either the next LoadCmpBlock or
// the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled with
// a special case through emitLoadCompareByteBlock. The special handling can
// simply subtract the loaded values and add it to the result phi node.
void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
  // There is one load per block in this case, BlockIndex == LoadIndex.
  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];

  if (CurLoadEntry.LoadSize == 1) {
    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex,
                                              CurLoadEntry.getGEPIndex());
    return;
  }

  Type *LoadSizeType =
      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Get the base address using a GEP.
  if (CurLoadEntry.Offset != 0) {
    Source1 = Builder.CreateGEP(
        LoadSizeType, Source1,
        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
    Source2 = Builder.CreateGEP(
        LoadSizeType, Source2,
        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
  }

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  if (DL.isLittleEndian()) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (LoadSizeType != MaxLoadType) {
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
  }

  // Add the loaded values to the phi nodes for calculating memcmp result only
  // if result is not used in a zero equality.
  if (!IsUsedForZeroCmp) {
    ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
    ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
  }

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise, continue
  // to next LoadCmpBlock or EndBlock.
  BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function populates the ResultBlock with a sequence to calculate the
// memcmp result. It compares the two loaded source values and returns -1 if
// src1 < src2 and 1 if src1 > src2.
void MemCmpExpansion::emitMemCmpResultBlock() {
  // Special case: if memcmp result is used in a zero equality, result does not
  // need to be calculated and can simply return 1.
  if (IsUsedForZeroCmp) {
    BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
    Builder.SetInsertPoint(ResBlock.BB, InsertPt);
    Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
    PhiRes->addIncoming(Res, ResBlock.BB);
    BranchInst *NewBr = BranchInst::Create(EndBlock);
    Builder.Insert(NewBr);
    return;
  }
  BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
  Builder.SetInsertPoint(ResBlock.BB, InsertPt);

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
                                  ResBlock.PhiSrc2);

  Value *Res =
      Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
                           ConstantInt::get(Builder.getInt32Ty(), 1));

  BranchInst *NewBr = BranchInst::Create(EndBlock);
  Builder.Insert(NewBr);
  PhiRes->addIncoming(Res, ResBlock.BB);
}

void MemCmpExpansion::setupResultBlockPHINodes() {
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  Builder.SetInsertPoint(ResBlock.BB);
  // Note: this assumes one load per block.
  ResBlock.PhiSrc1 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
  ResBlock.PhiSrc2 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
}

void MemCmpExpansion::setupEndBlockPHINodes() {
  Builder.SetInsertPoint(&EndBlock->front());
  PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
}

// Emits the multi-block expansion for a memcmp whose result is only compared
// against zero, and returns the phi holding the result.
Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
  const unsigned BlockCount = getNumBlocks();

  // Each block consumes the next batch of loads; NextLoad is advanced by the
  // callee so consecutive blocks pick up where the previous one stopped.
  unsigned NextLoad = 0;
  for (unsigned Block = 0; Block != BlockCount; ++Block)
    emitLoadCompareBlockMultipleLoads(Block, NextLoad);

  emitMemCmpResultBlock();
  return PhiRes;
}

/// Single-block expansion for a memcmp that is only compared against zero.
/// No branch or phi is needed: all load pairs are compared in place and the
/// resulting i1 ("the buffers differ") is widened to the i32 result.
Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
  unsigned ConsumedLoads = 0;
  Value *const IsDifferent =
      getCompareLoadPairs(/*BlockIndex=*/0, ConsumedLoads);
  assert(ConsumedLoads == getNumLoads() && "some entries were not consumed");

  IntegerType *const ResultTy = Type::getInt32Ty(CI->getContext());
  return Builder.CreateZExt(IsDifferent, ResultTy);
}

/// A memcmp expansion that only has one block of load and compare can bypass
/// the compare, branch, and phi IR that is required in the general case.
///
/// A single load of Size bytes is emitted for each source at the builder's
/// current insertion point. Returns an i32 that is negative, zero or positive
/// according to the order of the two buffers.
Value *MemCmpExpansion::getMemCmpOneBlock() {
  assert(NumLoadsPerBlock == 1 && "Only handles one load pair per block");

  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
  Type *LoadSizePtrType = LoadSizeType->getPointerTo();
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  // Cast source to LoadSizeType*. The check must be made against the pointer
  // type: comparing a pointer type with the integer type is always unequal.
  if (Source1->getType() != LoadSizePtrType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizePtrType);
  if (Source2->getType() != LoadSizePtrType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizePtrType);

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  // On little-endian targets, swap the bytes so that an unsigned integer
  // comparison agrees with the byte-wise order used by memcmp. A single byte
  // needs no swapping.
  if (DL.isLittleEndian() && Size != 1) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (Size < 4) {
    // The i8 and i16 cases don't need compares. We zext the loaded values and
    // subtract them to get the suitable negative, zero, or positive i32 result.
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
    return Builder.CreateSub(LoadSrc1, LoadSrc2);
  }

  // The result of memcmp is negative, zero, or positive, so produce that by
  // subtracting 2 extended compare bits: sub (ugt, ult).
  // If a target prefers to use selects to get -1/0/1, they should be able
  // to transform this later. The inverse transform (going from selects to math)
  // may not be possible in the DAG because the selects got converted into
  // branches before we got there.
  Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
  Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
  Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
  Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
  return Builder.CreateSub(ZextUGT, ZextULT);
}

// Emits the inline expansion of the memcmp call and returns the value holding
// the memcmp result. The call itself is left in place; replacing and erasing
// it is up to the caller.
Value *MemCmpExpansion::getMemCmpExpansion() {
  const unsigned BlockCount = getNumBlocks();
  const bool IsSingleBlock = BlockCount == 1;

  // An expansion made of a single load-compare block does not need any extra
  // blocks and is emitted straight in front of the call. This case could be
  // handled in the DAG, but since all of the machinery to flexibly expand any
  // memcmp is here, it is handled too to avoid fragmented lowering. The one
  // exception is a non-zero-equality use with several loads per block, which
  // still goes through the general path.
  const bool NeedsControlFlow =
      !IsSingleBlock || (!IsUsedForZeroCmp && NumLoadsPerBlock != 1);
  if (NeedsControlFlow) {
    BasicBlock *const StartBlock = CI->getParent();
    EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
    setupEndBlockPHINodes();
    createResultBlock();

    // When the result is not just tested against zero, the result block has
    // to work out which source was larger. That needs the two values loaded
    // by the block that found the difference, which are carried by the phis
    // created here.
    if (!IsUsedForZeroCmp)
      setupResultBlockPHINodes();

    // Create the number of required load compare basic blocks.
    createLoadCmpBlocks();

    // splitBasicBlock() made StartBlock branch to EndBlock; route it through
    // the first load-compare block instead.
    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
  }

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  if (IsUsedForZeroCmp) {
    if (IsSingleBlock)
      return getMemCmpEqZeroOneBlock();
    return getMemCmpExpansionZeroCase();
  }

  // TODO: Handle more than one load pair per block in getMemCmpOneBlock().
  if (IsSingleBlock && NumLoadsPerBlock == 1)
    return getMemCmpOneBlock();

  for (unsigned Block = 0; Block != BlockCount; ++Block)
    emitLoadCompareBlock(Block);

  emitMemCmpResultBlock();
  return PhiRes;
}

// This function checks to see if an expansion of memcmp can be generated.
// It checks for constant compare size that is less than the max inline size.
// If an expansion cannot occur, returns false to leave as a library call.
// Otherwise, the library call is replaced with a new IR instruction sequence.
/// We want to transform:
/// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
/// To:
/// loadbb:
///  %0 = bitcast i32* %buffer2 to i8*
///  %1 = bitcast i32* %buffer1 to i8*
///  %2 = bitcast i8* %1 to i64*
///  %3 = bitcast i8* %0 to i64*
///  %4 = load i64, i64* %2
///  %5 = load i64, i64* %3
///  %6 = call i64 @llvm.bswap.i64(i64 %4)
///  %7 = call i64 @llvm.bswap.i64(i64 %5)
///  %8 = sub i64 %6, %7
///  %9 = icmp ne i64 %8, 0
///  br i1 %9, label %res_block, label %loadbb1
/// res_block:                                        ; preds = %loadbb2,
/// %loadbb1, %loadbb
///  %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
///  %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
///  %10 = icmp ult i64 %phi.src1, %phi.src2
///  %11 = select i1 %10, i32 -1, i32 1
///  br label %endblock
/// loadbb1:                                          ; preds = %loadbb
///  %12 = bitcast i32* %buffer2 to i8*
///  %13 = bitcast i32* %buffer1 to i8*
///  %14 = bitcast i8* %13 to i32*
///  %15 = bitcast i8* %12 to i32*
///  %16 = getelementptr i32, i32* %14, i32 2
///  %17 = getelementptr i32, i32* %15, i32 2
///  %18 = load i32, i32* %16
///  %19 = load i32, i32* %17
///  %20 = call i32 @llvm.bswap.i32(i32 %18)
///  %21 = call i32 @llvm.bswap.i32(i32 %19)
///  %22 = zext i32 %20 to i64
///  %23 = zext i32 %21 to i64
///  %24 = sub i64 %22, %23
///  %25 = icmp ne i64 %24, 0
///  br i1 %25, label %res_block, label %loadbb2
/// loadbb2:                                          ; preds = %loadbb1
///  %26 = bitcast i32* %buffer2 to i8*
///  %27 = bitcast i32* %buffer1 to i8*
///  %28 = bitcast i8* %27 to i16*
///  %29 = bitcast i8* %26 to i16*
///  %30 = getelementptr i16, i16* %28, i16 6
///  %31 = getelementptr i16, i16* %29, i16 6
///  %32 = load i16, i16* %30
///  %33 = load i16, i16* %31
///  %34 = call i16 @llvm.bswap.i16(i16 %32)
///  %35 = call i16 @llvm.bswap.i16(i16 %33)
///  %36 = zext i16 %34 to i64
///  %37 = zext i16 %35 to i64
///  %38 = sub i64 %36, %37
///  %39 = icmp ne i64 %38, 0
///  br i1 %39, label %res_block, label %loadbb3
/// loadbb3:                                          ; preds = %loadbb2
///  %40 = bitcast i32* %buffer2 to i8*
///  %41 = bitcast i32* %buffer1 to i8*
///  %42 = getelementptr i8, i8* %41, i8 14
///  %43 = getelementptr i8, i8* %40, i8 14
///  %44 = load i8, i8* %42
///  %45 = load i8, i8* %43
///  %46 = zext i8 %44 to i32
///  %47 = zext i8 %45 to i32
///  %48 = sub i32 %46, %47
///  br label %endblock
/// endblock:                                         ; preds = %res_block,
/// %loadbb3
///  %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
///  ret i32 %phi.res
static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
                         const TargetLowering *TLI, const DataLayout *DL) {
  ++NumMemCmpCalls;

  const Function *const ParentFn = CI->getFunction();

  // Never expand when optimizing for minimum size (-Oz).
  if (ParentFn->optForMinSize())
    return false;

  // Only calls with a constant size can be expanded.
  const auto *const SizeConst = dyn_cast<ConstantInt>(CI->getArgOperand(2));
  if (!SizeConst) {
    ++NumMemCmpNotConstant;
    return false;
  }

  // A zero-sized memcmp is left alone.
  const uint64_t SizeVal = SizeConst->getZExtValue();
  if (SizeVal == 0)
    return false;

  // Ask the target whether it wants memcmp expanded for this kind of use, and
  // which load sizes are available.
  const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
  const auto *const Options = TTI->enableMemCmpExpansion(IsUsedForZeroCmp);
  if (!Options)
    return false;

  const unsigned MaxNumLoads =
      TLI->getMaxExpandSizeMemcmp(ParentFn->optForSize());

  MemCmpExpansion Expansion(CI, SizeVal, *Options, MaxNumLoads,
                            IsUsedForZeroCmp, MemCmpNumLoadsPerBlock, *DL);

  // An empty load sequence means the expansion was rejected, e.g. because it
  // would require more loads than desired by the target.
  if (Expansion.getNumLoads() == 0) {
    ++NumMemCmpGreaterThanMax;
    return false;
  }

  ++NumMemCmpInlined;

  // Emit the expansion, then substitute its result for the call and drop the
  // call.
  Value *const Replacement = Expansion.getMemCmpExpansion();
  CI->replaceAllUsesWith(Replacement);
  CI->eraseFromParent();

  return true;
}



class ExpandMemCmpPass : public FunctionPass {
public:
  static char ID;

  ExpandMemCmpPass() : FunctionPass(ID) {
    initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F)) return false;

    auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
    if (!TPC) {
      return false;
    }
    const TargetLowering* TL =
        TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();

    const TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
    const TargetTransformInfo *TTI =
        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto PA = runImpl(F, TLI, TTI, TL);
    return !PA.areAllPreserved();
  }

private:
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
                            const TargetTransformInfo *TTI,
                            const TargetLowering* TL);
  // Returns true if a change was made.
  bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                  const TargetTransformInfo *TTI, const TargetLowering* TL,
                  const DataLayout& DL);
};

// Scans BB for a call recognised as the memcmp library function and tries to
// expand it. Returns true as soon as one call has been expanded: the
// expansion rewrites the block, so the scan is not continued.
bool ExpandMemCmpPass::runOnBlock(
    BasicBlock &BB, const TargetLibraryInfo *TLI,
    const TargetTransformInfo *TTI, const TargetLowering* TL,
    const DataLayout& DL) {
  for (Instruction &Inst : BB) {
    auto *const Call = dyn_cast<CallInst>(&Inst);
    if (!Call)
      continue;

    LibFunc Func;
    if (!TLI->getLibFunc(ImmutableCallSite(Call), Func))
      continue;
    if (Func != LibFunc_memcmp)
      continue;

    if (expandMemCmp(Call, TTI, TL, &DL))
      return true;
  }
  return false;
}


// Expands the memcmp calls of F block by block. Reports that no analysis is
// preserved if anything was changed, and that all of them are otherwise.
PreservedAnalyses ExpandMemCmpPass::runImpl(
    Function &F, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI,
    const TargetLowering* TL) {
  const DataLayout &Layout = F.getParent()->getDataLayout();

  bool Changed = false;
  auto BlockIt = F.begin();
  while (BlockIt != F.end()) {
    if (!runOnBlock(*BlockIt, TLI, TTI, TL, Layout)) {
      ++BlockIt;
      continue;
    }
    // An expansion changes the structure of the function, so start over from
    // the first block.
    Changed = true;
    BlockIt = F.begin();
  }

  if (Changed)
    return PreservedAnalyses::none();
  return PreservedAnalyses::all();
}

} // namespace

// Definition of the pass identifier handed to the FunctionPass base class by
// the ExpandMemCmpPass constructor.
char ExpandMemCmpPass::ID = 0;
// Register the pass under the name "expandmemcmp". The declared dependencies
// mirror the analyses required in ExpandMemCmpPass::getAnalysisUsage().
INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
                      "Expand memcmp() to load/stores", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
                    "Expand memcmp() to load/stores", false, false)

// Public factory for the pass: returns a newly allocated ExpandMemCmpPass.
// The caller receives the raw pointer returned by `new`.
FunctionPass *llvm::createExpandMemCmpPass() {
  return new ExpandMemCmpPass();
}